Skip to content

Commit

Permalink
add optional params for audio api, e.g. prompt (sashabaranov#183)
Browse files Browse the repository at this point in the history
* Compatible with the situation where the mask is empty in CreateEditImage.

* Fix the test for the unnecessary removal of the mask.png file.

* add image variation implementation

* fix image variation bugs

* fix ci-lint problem with max line character limit

* add offitial doc link

* just for codeball test

* fix lint problem

* add optional params for audio api, e.g. prompt

* add comment for new args in translation
  • Loading branch information
itegel authored Mar 20, 2023
1 parent d529d13 commit aa149c1
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 2 deletions.
49 changes: 47 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ const (
)

// AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct {
Model string
FilePath string
Model string
FilePath string
Prompt string // For translation, it should be in English
Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
}

// AudioResponse represents a response structure for audio API.
Expand Down Expand Up @@ -94,6 +98,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
if _, err = io.Copy(fw, modelName); err != nil {
return fmt.Errorf("writing model name: %w", err)
}

// Create a form field for the prompt (if provided)
if request.Prompt != "" {
fw, err = w.CreateFormField("prompt")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}

prompt := bytes.NewReader([]byte(request.Prompt))
if _, err = io.Copy(fw, prompt); err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}

// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
fw, err = w.CreateFormField("temperature")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}

temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
if _, err = io.Copy(fw, temperature); err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}

// Create a form field for the language (if provided)
if request.Language != "" {
fw, err = w.CreateFormField("language")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}

language := bytes.NewReader([]byte(request.Language))
if _, err = io.Copy(fw, language); err != nil {
return fmt.Errorf("writing language: %w", err)
}
}

// Close the multipart writer
w.Close()

return nil
Expand Down
53 changes: 53 additions & 0 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,59 @@ func TestAudio(t *testing.T) {
}
}

func TestAudioWithOptionalArgs(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := createTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
createTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
Prompt: "用简体中文",
Temperature: 0.5,
Language: "zh",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
})
}
}

// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)
Expand Down

0 comments on commit aa149c1

Please sign in to comment.