Skip to content

Commit

Permalink
feat: implement new fine tuning job API (sashabaranov#479)
Browse files Browse the repository at this point in the history
* feat: implement new fine tuning job API

* fix: export ListFineTuningJobEventsParameter

* fix: lint errors

* fix: test errors

* fix: code test coverage

* fix: code test coverage

* fix: use any

* chore: use url.Values
  • Loading branch information
henomis authored Aug 29, 2023
1 parent a14bc10 commit a2ca01b
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 0 deletions.
12 changes: 12 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"ListFineTuneEvents", func() (any, error) {
return client.ListFineTuneEvents(ctx, "")
}},
{"CreateFineTuningJob", func() (any, error) {
return client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
}},
{"CancelFineTuningJob", func() (any, error) {
return client.CancelFineTuningJob(ctx, "")
}},
{"RetrieveFineTuningJob", func() (any, error) {
return client.RetrieveFineTuningJob(ctx, "")
}},
{"ListFineTuningJobEvents", func() (any, error) {
return client.ListFineTuningJobEvents(ctx, "")
}},
{"Moderations", func() (any, error) {
return client.Moderations(ctx, ModerationRequest{})
}},
Expand Down
153 changes: 153 additions & 0 deletions fine_tuning_job.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package openai

import (
"context"
"fmt"
"net/http"
"net/url"
)

type FineTuningJob struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
FinishedAt int64 `json:"finished_at"`
Model string `json:"model"`
FineTunedModel string `json:"fine_tuned_model,omitempty"`
OrganizationID string `json:"organization_id"`
Status string `json:"status"`
Hyperparameters Hyperparameters `json:"hyperparameters"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
ResultFiles []string `json:"result_files"`
TrainedTokens int `json:"trained_tokens"`
}

type Hyperparameters struct {
Epochs int `json:"n_epochs"`
}

type FineTuningJobRequest struct {
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
Suffix string `json:"suffix,omitempty"`
}

type FineTuningJobEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
HasMore bool `json:"has_more"`
}

type FineTuningJobEvent struct {
Object string `json:"object"`
ID string `json:"id"`
CreatedAt int `json:"created_at"`
Level string `json:"level"`
Message string `json:"message"`
Data any `json:"data"`
Type string `json:"type"`
}

// CreateFineTuningJob create a fine tuning job.
func (c *Client) CreateFineTuningJob(
ctx context.Context,
request FineTuningJobRequest,
) (response FineTuningJob, err error) {
urlSuffix := "/fine_tuning/jobs"
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// CancelFineTuningJob cancel a fine tuning job.
func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// RetrieveFineTuningJob retrieve a fine tuning job.
func (c *Client) RetrieveFineTuningJob(
ctx context.Context,
fineTuningJobID string,
) (response FineTuningJob, err error) {
urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

type listFineTuningJobEventsParameters struct {
after *string
limit *int
}

type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)

func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.after = &after
}
}

func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.limit = &limit
}
}

// ListFineTuningJobs list fine tuning jobs events.
func (c *Client) ListFineTuningJobEvents(
ctx context.Context,
fineTuningJobID string,
setters ...ListFineTuningJobEventsParameter,
) (response FineTuningJobEventList, err error) {
parameters := &listFineTuningJobEventsParameters{
after: nil,
limit: nil,
}

for _, setter := range setters {
setter(parameters)
}

urlValues := url.Values{}
if parameters.after != nil {
urlValues.Add("after", *parameters.after)
}
if parameters.limit != nil {
urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
}

encodedValues := ""
if len(urlValues) > 0 {
encodedValues = "?" + urlValues.Encode()
}

req, err := c.newRequest(
ctx,
http.MethodGet,
c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}
90 changes: 90 additions & 0 deletions fine_tuning_job_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package openai_test

import (
"context"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"encoding/json"
"fmt"
"net/http"
"testing"
)

const testFineTuninigJobID = "fine-tuning-job-id"

// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server.
func TestFineTuningJob(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler(
"/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJobEventList{})
fmt.Fprintln(w, string(resBytes))
},
)

ctx := context.Background()

_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
checks.NoError(t, err, "CreateFineTuningJob error")

_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "CancelFineTuningJob error")

_, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "RetrieveFineTuningJob error")

_, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")

_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")
}

0 comments on commit a2ca01b

Please sign in to comment.