Skip to content

Commit

Permalink
Feat Add headers to openai responses (sashabaranov#506)
Browse files Browse the repository at this point in the history
* feat: add headers to http response

* chore: add test

* fix: rename to httpHeader
  • Loading branch information
henomis authored Oct 9, 2023
1 parent 533935e commit 8e165dc
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 2 deletions.
19 changes: 18 additions & 1 deletion audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ type AudioResponse struct {
Transient bool `json:"transient"`
} `json:"segments"`
Text string `json:"text"`

httpHeader
}

type audioTextResponse struct {
Text string `json:"text"`

httpHeader
}

func (r *audioTextResponse) ToAudioResponse() AudioResponse {
return AudioResponse{
Text: r.Text,
httpHeader: r.httpHeader,
}
}

// CreateTranscription — API call to create a transcription. Returns transcribed text.
Expand Down Expand Up @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI(
if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
} else {
err = c.sendRequest(req, &response.Text)
var textResponse audioTextResponse
err = c.sendRequest(req, &textResponse)
response = textResponse.ToAudioResponse()
}
if err != nil {
return AudioResponse{}, err
Expand Down
2 changes: 2 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ type ChatCompletionResponse struct {
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage Usage `json:"usage"`

httpHeader
}

// CreateChatCompletion — API call to Create a completion for the chat message.
Expand Down
30 changes: 30 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

const (
xCustomHeader = "X-CUSTOM-HEADER"
xCustomHeaderValue = "test"
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
Expand Down Expand Up @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")

a := resp.Header().Get(xCustomHeader)
_ = a
if resp.Header().Get(xCustomHeader) != xCustomHeaderValue {
t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue)
}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
Expand Down Expand Up @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
fmt.Fprintln(w, string(resBytes))
}

Expand Down
20 changes: 19 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ type Client struct {
createFormBuilder func(io.Writer) utils.FormBuilder
}

type Response interface {
SetHeader(http.Header)
}

type httpHeader http.Header

func (h *httpHeader) SetHeader(header http.Header) {
*h = httpHeader(header)
}

func (h httpHeader) Header() http.Header {
return http.Header(h)
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
Expand Down Expand Up @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
return req, nil
}

func (c *Client) sendRequest(req *http.Request, v any) error {
func (c *Client) sendRequest(req *http.Request, v Response) error {
req.Header.Set("Accept", "application/json; charset=utf-8")

// Check whether Content-Type is already set, Upload Files API requires
Expand All @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
return c.handleErrorResp(res)
}

if v != nil {
v.SetHeader(res.Header)
}

return decodeResponse(res.Body, v)
}

Expand Down
2 changes: 2 additions & 0 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ type CompletionResponse struct {
Model string `json:"model"`
Choices []CompletionChoice `json:"choices"`
Usage Usage `json:"usage"`

httpHeader
}

// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
Expand Down
2 changes: 2 additions & 0 deletions edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type EditsResponse struct {
Created int64 `json:"created"`
Usage Usage `json:"usage"`
Choices []EditsChoice `json:"choices"`

httpHeader
}

// Edits Perform an API call to the Edits endpoint.
Expand Down
4 changes: 4 additions & 0 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ type EmbeddingResponse struct {
Data []Embedding `json:"data"`
Model EmbeddingModel `json:"model"`
Usage Usage `json:"usage"`

httpHeader
}

type base64String string
Expand Down Expand Up @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct {
Data []Base64Embedding `json:"data"`
Model EmbeddingModel `json:"model"`
Usage Usage `json:"usage"`

httpHeader
}

// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.
Expand Down
4 changes: 4 additions & 0 deletions engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ type Engine struct {
Object string `json:"object"`
Owner string `json:"owner"`
Ready bool `json:"ready"`

httpHeader
}

// EnginesList is a list of engines.
type EnginesList struct {
Engines []Engine `json:"data"`

httpHeader
}

// ListEngines Lists the currently available engines, and provides basic
Expand Down
4 changes: 4 additions & 0 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ type File struct {
Status string `json:"status"`
Purpose string `json:"purpose"`
StatusDetails string `json:"status_details"`

httpHeader
}

// FilesList is a list of files that belong to the user or organization.
type FilesList struct {
Files []File `json:"data"`

httpHeader
}

// CreateFile uploads a jsonl file to GPT3
Expand Down
8 changes: 8 additions & 0 deletions fine_tunes.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type FineTune struct {
ValidationFiles []File `json:"validation_files"`
TrainingFiles []File `json:"training_files"`
UpdatedAt int64 `json:"updated_at"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand Down Expand Up @@ -69,6 +71,8 @@ type FineTuneHyperParams struct {
type FineTuneList struct {
Object string `json:"object"`
Data []FineTune `json:"data"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand All @@ -77,6 +81,8 @@ type FineTuneList struct {
type FineTuneEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand All @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`

httpHeader
}

// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
Expand Down
4 changes: 4 additions & 0 deletions fine_tuning_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type FineTuningJob struct {
ValidationFile string `json:"validation_file,omitempty"`
ResultFiles []string `json:"result_files"`
TrainedTokens int `json:"trained_tokens"`

httpHeader
}

type Hyperparameters struct {
Expand All @@ -39,6 +41,8 @@ type FineTuningJobEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
HasMore bool `json:"has_more"`

httpHeader
}

type FineTuningJobEvent struct {
Expand Down
2 changes: 2 additions & 0 deletions image.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type ImageRequest struct {
type ImageResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"`

httpHeader
}

// ImageResponseDataInner represents a response data structure for image API.
Expand Down
6 changes: 6 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Model struct {
Permission []Permission `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`

httpHeader
}

// Permission struct represents an OpenAPI permission.
Expand All @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`

httpHeader
}

// ModelsList is a list of models, including those that belong to the user or organization.
type ModelsList struct {
Models []Model `json:"data"`

httpHeader
}

// ListModels Lists the currently available models,
Expand Down
2 changes: 2 additions & 0 deletions moderation.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type ModerationResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Results []Result `json:"results"`

httpHeader
}

// Moderations — perform a moderation api call over a string.
Expand Down

0 comments on commit 8e165dc

Please sign in to comment.