From ecdea45b6753592d6e0d39adca7416e77e63d7e4 Mon Sep 17 00:00:00 2001 From: Hoani Bryson Date: Fri, 21 Apr 2023 01:07:04 +1200 Subject: [PATCH] Adds support for audio captioning with Whisper (#267) * Add speech to text example in docs * Add caption formats for audio transcription * Add caption example to README * Address sanity check errors * Add tests for decodeResponse * Use typechecker for audio response format * Decoding response refactors --- README.md | 43 ++++++++++++++++++++++++++++++++++++++++++- audio.go | 29 ++++++++++++++++++++++++++++- audio_test.go | 4 +++- client.go | 24 +++++++++++++++++++----- client_test.go | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 898465cdd..7526ea333 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,47 @@ func main() { ``` +
+Audio Captions + +```go +package main + +import ( + "context" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient(os.Getenv("OPENAI_KEY")) + + req := openai.AudioRequest{ + Model: openai.Whisper1, + FilePath: os.Args[1], + Format: openai.AudioResponseFormatSRT, + } + resp, err := c.CreateTranscription(context.Background(), req) + if err != nil { + fmt.Printf("Transcription error: %v\n", err) + return + } + f, err := os.Create(os.Args[1] + ".srt") + if err != nil { + fmt.Printf("Could not open file: %v\n", err) + return + } + defer f.Close() + if _, err := f.WriteString(resp.Text); err != nil { + fmt.Printf("Error writing to file: %v\n", err) + return + } +} +``` +
+
DALL-E 2 image generation @@ -420,4 +461,4 @@ func main() { fmt.Println(resp.Choices[0].Message.Content) } ``` -
\ No newline at end of file + diff --git a/audio.go b/audio.go index 9db9298e3..46c37112b 100644 --- a/audio.go +++ b/audio.go @@ -13,6 +13,15 @@ const ( Whisper1 = "whisper-1" ) +// Response formats; Whisper uses AudioResponseFormatJSON by default. +type AudioResponseFormat string + +const ( + AudioResponseFormatJSON AudioResponseFormat = "json" + AudioResponseFormatSRT AudioResponseFormat = "srt" + AudioResponseFormatVTT AudioResponseFormat = "vtt" +) + // AudioRequest represents a request structure for audio API. // ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { @@ -21,6 +30,7 @@ type AudioRequest struct { Prompt string // For translation, it should be in English Temperature float32 Language string // For translation, just do not use it. It seems "en" works, not confirmed... + Format AudioResponseFormat } // AudioResponse represents a response structure for audio API. @@ -66,10 +76,19 @@ func (c *Client) callAudioAPI( } req.Header.Add("Content-Type", builder.formDataContentType()) - err = c.sendRequest(req, &response) + if request.HasJSONResponse() { + err = c.sendRequest(req, &response) + } else { + err = c.sendRequest(req, &response.Text) + } return } +// HasJSONResponse returns true if the response format is JSON. +func (r AudioRequest) HasJSONResponse() bool { + return r.Format == "" || r.Format == AudioResponseFormatJSON +} + // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. func audioMultipartForm(request AudioRequest, b formBuilder) error { @@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { } } + // Create a form field for the format (if provided) + if request.Format != "" { + err = b.writeField("response_format", string(request.Format)) + if err != nil { + return fmt.Errorf("writing format: %w", err) + } + } + // Create a form field for the temperature (if provided) if request.Temperature != 0 { err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) diff --git a/audio_test.go b/audio_test.go index 9d2abfc50..daf51f28c 100644 --- a/audio_test.go +++ b/audio_test.go @@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { Prompt: "用简体中文", Temperature: 0.5, Language: "zh", + Format: AudioResponseFormatSRT, } _, err = tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") @@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Prompt: "test", Temperature: 0.5, Language: "en", + Format: AudioResponseFormatSRT, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format"} for _, failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) diff --git a/client.go b/client.go index b15a18ae1..e17ded238 100644 --- a/client.go +++ b/client.go @@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client { return NewClientWithConfig(config) } -func (c *Client) sendRequest(req *http.Request, v interface{}) error { +func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Azure API Key authentication if c.config.APIType == APITypeAzure { @@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { return c.handleErrorResp(res) } - if v != nil { - if err = json.NewDecoder(res.Body).Decode(v); err != nil { - return err - } + return decodeResponse(res.Body, v) +} + +func decodeResponse(body io.Reader, v any) error { + if v == nil { + return nil } + if result, ok := v.(*string); ok { + return decodeString(body, result) + } + return json.NewDecoder(body).Decode(v) +} + +func decodeString(body io.Reader, output *string) error { + b, err := io.ReadAll(body) + if err != nil { + return err + } + *output = string(b) return nil } diff --git a/client_test.go b/client_test.go index 1c15985d6..7bea6dd87 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,8 @@ package openai //nolint:testpackage // testing private field import ( + "bytes" + "io" "testing" ) @@ -20,3 +22,38 @@ func TestClient(t *testing.T) { t.Errorf("Client does not contain proper orgID") } } + +func TestDecodeResponse(t *testing.T) { + stringInput := "" + + testCases := []struct { + name string + value interface{} + body io.Reader + }{ + { + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + }, + { + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + }, + { + name: "map input", + value: &map[string]interface{}{}, + body: bytes.NewReader([]byte(`{"test": "test"}`)), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := decodeResponse(tc.body, tc.value) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +}