Skip to content

Commit

Permalink
Feat Implement threads API (sashabaranov#536)
Browse files Browse the repository at this point in the history
* feat: implement threads API

* fix

* add tests

* fix

* trigger£

* trigger

* chore: add beta header
  • Loading branch information
henomis authored Nov 9, 2023
1 parent 08c167f commit bc89139
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
12 changes: 12 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"DeleteAssistantFile", func() (any, error) {
return nil, client.DeleteAssistantFile(ctx, "", "")
}},
{"CreateThread", func() (any, error) {
return client.CreateThread(ctx, ThreadRequest{})
}},
{"RetrieveThread", func() (any, error) {
return client.RetrieveThread(ctx, "")
}},
{"ModifyThread", func() (any, error) {
return client.ModifyThread(ctx, "", ModifyThreadRequest{})
}},
{"DeleteThread", func() (any, error) {
return client.DeleteThread(ctx, "")
}},
}

for _, testCase := range testCases {
Expand Down
107 changes: 107 additions & 0 deletions thread.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package openai

import (
"context"
"net/http"
)

const (
threadsSuffix = "/threads"
)

type Thread struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Metadata map[string]any `json:"metadata"`

httpHeader
}

type ThreadRequest struct {
Messages []ThreadMessage `json:"messages,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}

type ModifyThreadRequest struct {
Metadata map[string]any `json:"metadata"`
}

type ThreadMessageRole string

const (
ThreadMessageRoleUser ThreadMessageRole = "user"
)

type ThreadMessage struct {
Role ThreadMessageRole `json:"role"`
Content string `json:"content"`
FileIDs []string `json:"file_ids,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}

type ThreadDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`

httpHeader
}

// CreateThread creates a new thread.
func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request),
withBetaAssistantV1())
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// RetrieveThread retrieves a thread.
func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) {
urlSuffix := threadsSuffix + "/" + threadID
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
withBetaAssistantV1())
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// ModifyThread modifies a thread.
func (c *Client) ModifyThread(
ctx context.Context,
threadID string,
request ModifyThreadRequest,
) (response Thread, err error) {
urlSuffix := threadsSuffix + "/" + threadID
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request),
withBetaAssistantV1())
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// DeleteThread deletes a thread.
func (c *Client) DeleteThread(
ctx context.Context,
threadID string,
) (response ThreadDeleteResponse, err error) {
urlSuffix := threadsSuffix + "/" + threadID
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix),
withBetaAssistantV1())
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}
95 changes: 95 additions & 0 deletions thread_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package openai_test

import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"

openai "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestThread Tests the thread endpoint of the API using the mocked server.
func TestThread(t *testing.T) {
threadID := "thread_abc123"
client, server, teardown := setupOpenAITestServer()
defer teardown()

server.RegisterHandler(
"/v1/threads/"+threadID,
func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
resBytes, _ := json.Marshal(openai.Thread{
ID: threadID,
Object: "thread",
CreatedAt: 1234567890,
})
fmt.Fprintln(w, string(resBytes))
case http.MethodPost:
var request openai.ThreadRequest
err := json.NewDecoder(r.Body).Decode(&request)
checks.NoError(t, err, "Decode error")

resBytes, _ := json.Marshal(openai.Thread{
ID: threadID,
Object: "thread",
CreatedAt: 1234567890,
})
fmt.Fprintln(w, string(resBytes))
case http.MethodDelete:
fmt.Fprintln(w, `{
"id": "thread_abc123",
"object": "thread.deleted",
"deleted": true
}`)
}
},
)

server.RegisterHandler(
"/v1/threads",
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
var request openai.ModifyThreadRequest
err := json.NewDecoder(r.Body).Decode(&request)
checks.NoError(t, err, "Decode error")

resBytes, _ := json.Marshal(openai.Thread{
ID: threadID,
Object: "thread",
CreatedAt: 1234567890,
Metadata: request.Metadata,
})
fmt.Fprintln(w, string(resBytes))
}
},
)

ctx := context.Background()

_, err := client.CreateThread(ctx, openai.ThreadRequest{
Messages: []openai.ThreadMessage{
{
Role: openai.ThreadMessageRoleUser,
Content: "Hello, World!",
},
},
})
checks.NoError(t, err, "CreateThread error")

_, err = client.RetrieveThread(ctx, threadID)
checks.NoError(t, err, "RetrieveThread error")

_, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{
Metadata: map[string]interface{}{
"key": "value",
},
})
checks.NoError(t, err, "ModifyThread error")

_, err = client.DeleteThread(ctx, threadID)
checks.NoError(t, err, "DeleteThread error")
}

0 comments on commit bc89139

Please sign in to comment.