diff --git a/fine_tunes.go b/fine_tunes.go new file mode 100644 index 000000000..82af2c082 --- /dev/null +++ b/fine_tunes.go @@ -0,0 +1,137 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +type FineTuneRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Epochs int `json:"n_epochs,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"` + PromptLossRate float32 `json:"prompt_loss_rate,omitempty"` + ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"` + ClassificationClasses int `json:"classification_n_classes,omitempty"` + ClassificationPositiveClass string `json:"classification_positive_class,omitempty"` + ClassificationBetas []float32 `json:"classification_betas,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTune struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + CreatedAt int `json:"created_at"` + FineTunedModel string `json:"fine_tuned_model"` + Hyperparams FineTuneHyperParams `json:"hyperparams"` + OrganizationID string `json:"organization_id"` + ResultFiles []File `json:"result_files"` + Status string `json:"status"` + ValidationFiles []File `json:"validation_files"` + TrainingFiles []File `json:"training_files"` + UpdatedAt int `json:"updated_at"` +} + +type FineTuneEvent struct { + Object string `json:"object"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` +} + +type FineTuneHyperParams struct { + BatchSize int `json:"batch_size"` + LearningRateMultiplier float64 `json:"learning_rate_multiplier"` + Epochs int `json:"n_epochs"` + PromptLossWeight float64 `json:"prompt_loss_weight"` +} + +type FineTuneList struct { + Object string `json:"object"` + Data []FineTune `json:"data"` +} +type FineTuneEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` +} + +type FineTuneDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { + var reqBytes []byte + reqBytes, err = json.Marshal(request) + if err != nil { + return + } + + urlSuffix := "/fine-tunes" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// Cancel a fine-tune job. +func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +}