-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
package openai_test | ||
|
||
import ( | ||
. "github.com/sashabaranov/go-openai" | ||
"github.com/sashabaranov/go-openai/internal/test" | ||
|
||
"context" | ||
"encoding/json" | ||
"errors" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
) | ||
|
||
func TestCreateChatCompletionStream(t *testing.T) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Set("Content-Type", "text/event-stream") | ||
|
||
// Send test responses | ||
dataBytes := []byte{} | ||
dataBytes = append(dataBytes, []byte("event: message\n")...) | ||
//nolint:lll | ||
data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` | ||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | ||
|
||
dataBytes = append(dataBytes, []byte("event: message\n")...) | ||
//nolint:lll | ||
data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` | ||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | ||
|
||
dataBytes = append(dataBytes, []byte("event: done\n")...) | ||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) | ||
|
||
_, err := w.Write(dataBytes) | ||
if err != nil { | ||
t.Errorf("Write error: %s", err) | ||
} | ||
})) | ||
defer server.Close() | ||
|
||
// Client portion of the test | ||
config := DefaultConfig(test.GetTestToken()) | ||
config.BaseURL = server.URL + "/v1" | ||
config.HTTPClient.Transport = &tokenRoundTripper{ | ||
test.GetTestToken(), | ||
http.DefaultTransport, | ||
} | ||
|
||
client := NewClientWithConfig(config) | ||
ctx := context.Background() | ||
|
||
request := ChatCompletionRequest{ | ||
MaxTokens: 5, | ||
Model: GPT3Dot5Turbo, | ||
Messages: []ChatCompletionMessage{ | ||
{ | ||
Role: ChatMessageRoleUser, | ||
Content: "Hello!", | ||
}, | ||
}, | ||
Stream: true, | ||
} | ||
|
||
stream, err := client.CreateChatCompletionStream(ctx, request) | ||
if err != nil { | ||
t.Errorf("CreateCompletionStream returned error: %v", err) | ||
} | ||
defer stream.Close() | ||
|
||
expectedResponses := []ChatCompletionStreamResponse{ | ||
{ | ||
ID: "1", | ||
Object: "completion", | ||
Created: 1598069254, | ||
Model: GPT3Dot5Turbo, | ||
Choices: []ChatCompletionStreamChoice{ | ||
{ | ||
Delta: ChatCompletionStreamChoiceDelta{ | ||
Content: "response1", | ||
}, | ||
FinishReason: "max_tokens", | ||
}, | ||
}, | ||
}, | ||
{ | ||
ID: "2", | ||
Object: "completion", | ||
Created: 1598069255, | ||
Model: GPT3Dot5Turbo, | ||
Choices: []ChatCompletionStreamChoice{ | ||
{ | ||
Delta: ChatCompletionStreamChoiceDelta{ | ||
Content: "response2", | ||
}, | ||
FinishReason: "max_tokens", | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
for ix, expectedResponse := range expectedResponses { | ||
b, _ := json.Marshal(expectedResponse) | ||
t.Logf("%d: %s", ix, string(b)) | ||
|
||
receivedResponse, streamErr := stream.Recv() | ||
if streamErr != nil { | ||
t.Errorf("stream.Recv() failed: %v", streamErr) | ||
} | ||
if !compareChatResponses(expectedResponse, receivedResponse) { | ||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) | ||
} | ||
} | ||
|
||
_, streamErr := stream.Recv() | ||
if !errors.Is(streamErr, io.EOF) { | ||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) | ||
} | ||
} | ||
|
||
// Helper funcs. | ||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { | ||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { | ||
return false | ||
} | ||
if len(r1.Choices) != len(r2.Choices) { | ||
return false | ||
} | ||
for i := range r1.Choices { | ||
if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { | ||
if c1.Index != c2.Index { | ||
return false | ||
} | ||
if c1.Delta.Content != c2.Delta.Content { | ||
return false | ||
} | ||
if c1.FinishReason != c2.FinishReason { | ||
return false | ||
} | ||
return true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
package openai_test | ||
|
||
import ( | ||
. "github.com/sashabaranov/go-openai" | ||
"github.com/sashabaranov/go-openai/internal/test" | ||
|
||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestChatCompletionsWrongModel(t *testing.T) { | ||
config := DefaultConfig("whatever") | ||
config.BaseURL = "http://localhost/v1" | ||
client := NewClientWithConfig(config) | ||
ctx := context.Background() | ||
|
||
req := ChatCompletionRequest{ | ||
MaxTokens: 5, | ||
Model: "ada", | ||
Messages: []ChatCompletionMessage{ | ||
{ | ||
Role: ChatMessageRoleUser, | ||
Content: "Hello!", | ||
}, | ||
}, | ||
} | ||
_, err := client.CreateChatCompletion(ctx, req) | ||
if !errors.Is(err, ErrChatCompletionInvalidModel) { | ||
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) | ||
} | ||
} | ||
|
||
// TestCompletions Tests the completions endpoint of the API using the mocked server. | ||
func TestChatCompletions(t *testing.T) { | ||
server := test.NewTestServer() | ||
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) | ||
// create the test server | ||
var err error | ||
ts := server.OpenAITestServer() | ||
ts.Start() | ||
defer ts.Close() | ||
|
||
config := DefaultConfig(test.GetTestToken()) | ||
config.BaseURL = ts.URL + "/v1" | ||
client := NewClientWithConfig(config) | ||
ctx := context.Background() | ||
|
||
req := ChatCompletionRequest{ | ||
MaxTokens: 5, | ||
Model: GPT3Dot5Turbo, | ||
Messages: []ChatCompletionMessage{ | ||
{ | ||
Role: ChatMessageRoleUser, | ||
Content: "Hello!", | ||
}, | ||
}, | ||
} | ||
_, err = client.CreateChatCompletion(ctx, req) | ||
if err != nil { | ||
t.Fatalf("CreateChatCompletion error: %v", err) | ||
} | ||
} | ||
|
||
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. | ||
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { | ||
var err error | ||
var resBytes []byte | ||
|
||
// completions only accepts POST requests | ||
if r.Method != "POST" { | ||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | ||
} | ||
var completionReq ChatCompletionRequest | ||
if completionReq, err = getChatCompletionBody(r); err != nil { | ||
http.Error(w, "could not read request", http.StatusInternalServerError) | ||
return | ||
} | ||
res := ChatCompletionResponse{ | ||
ID: strconv.Itoa(int(time.Now().Unix())), | ||
Object: "test-object", | ||
Created: time.Now().Unix(), | ||
// would be nice to validate Model during testing, but | ||
// this may not be possible with how much upkeep | ||
// would be required / wouldn't make much sense | ||
Model: completionReq.Model, | ||
} | ||
// create completions | ||
for i := 0; i < completionReq.N; i++ { | ||
// generate a random string of length completionReq.Length | ||
completionStr := strings.Repeat("a", completionReq.MaxTokens) | ||
|
||
res.Choices = append(res.Choices, ChatCompletionChoice{ | ||
Message: ChatCompletionMessage{ | ||
Role: ChatMessageRoleAssistant, | ||
Content: completionStr, | ||
}, | ||
Index: i, | ||
}) | ||
} | ||
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N | ||
completionTokens := completionReq.MaxTokens * completionReq.N | ||
res.Usage = Usage{ | ||
PromptTokens: inputTokens, | ||
CompletionTokens: completionTokens, | ||
TotalTokens: inputTokens + completionTokens, | ||
} | ||
resBytes, _ = json.Marshal(res) | ||
fmt.Fprintln(w, string(resBytes)) | ||
} | ||
|
||
// getChatCompletionBody Returns the body of the request to create a completion. | ||
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { | ||
completion := ChatCompletionRequest{} | ||
// read the request body | ||
reqBody, err := io.ReadAll(r.Body) | ||
if err != nil { | ||
return ChatCompletionRequest{}, err | ||
} | ||
err = json.Unmarshal(reqBody, &completion) | ||
if err != nil { | ||
return ChatCompletionRequest{}, err | ||
} | ||
return completion, nil | ||
} |