
Add dependency injection to ChatCompletionStream for improved testability #1011


Open · wants to merge 1 commit into base: master
45 changes: 43 additions & 2 deletions chat_stream.go
@@ -65,10 +65,21 @@
Usage *Usage `json:"usage,omitempty"`
}

// ChatStreamReader is an interface for reading chat completion streams.
type ChatStreamReader interface {
Recv() (ChatCompletionStreamResponse, error)
Close() error
}

// ChatCompletionStream
// Note: Perhaps it is more elegant to abstract Stream using generics.
type ChatCompletionStream struct {
*streamReader[ChatCompletionStreamResponse]
reader ChatStreamReader
}

// NewChatCompletionStream allows injecting a custom ChatStreamReader (for testing).
func NewChatCompletionStream(reader ChatStreamReader) *ChatCompletionStream {
return &ChatCompletionStream{reader: reader}
Author comment on lines +68 to +82:
@sashabaranov Does this approach fulfill your intent in the Note comment?
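For context, a minimal sketch of what that generics-based abstraction might look like. This is hypothetical and not part of this PR; it assumes streamReader's current Recv/Close signatures:

// Hypothetical sketch only, not part of this diff.
// One generic interface could cover chat and completion streams alike.
type Streamer[T any] interface {
	Recv() (T, error)
	Close() error
}

// Assuming streamReader[T]'s existing method set, both stream types
// would satisfy it:
var _ Streamer[ChatCompletionStreamResponse] = (*streamReader[ChatCompletionStreamResponse])(nil)
var _ Streamer[CompletionResponse] = (*streamReader[CompletionResponse])(nil)

ChatCompletionStream would then embed a Streamer[ChatCompletionStreamResponse] rather than the chat-specific ChatStreamReader.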

}

// CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -106,7 +117,37 @@
return
}
stream = &ChatCompletionStream{
streamReader: resp,
reader: resp,
}
return
}

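// Recv returns the next response from the injected reader.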
func (s *ChatCompletionStream) Recv() (ChatCompletionStreamResponse, error) {
return s.reader.Recv()
}

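// Close closes the injected reader.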
func (s *ChatCompletionStream) Close() error {
return s.reader.Close()
}

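// Header returns the underlying HTTP response headers when the injected
// reader exposes them, or an empty header set otherwise.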
func (s *ChatCompletionStream) Header() http.Header {
if h, ok := s.reader.(interface{ Header() http.Header }); ok {
return h.Header()
}
return http.Header{}

Codecov warning: added line chat_stream.go#L137 was not covered by tests.
}

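// GetRateLimitHeaders returns the reader's rate-limit headers as a map
// when available, or an empty map otherwise.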
func (s *ChatCompletionStream) GetRateLimitHeaders() map[string]interface{} {
if h, ok := s.reader.(interface{ GetRateLimitHeaders() RateLimitHeaders }); ok {
headers := h.GetRateLimitHeaders()
return map[string]interface{}{
"x-ratelimit-limit-requests": headers.LimitRequests,
"x-ratelimit-limit-tokens": headers.LimitTokens,
"x-ratelimit-remaining-requests": headers.RemainingRequests,
"x-ratelimit-remaining-tokens": headers.RemainingTokens,
"x-ratelimit-reset-requests": headers.ResetRequests.String(),
"x-ratelimit-reset-tokens": headers.ResetTokens.String(),
}
}
return map[string]interface{}{}

Codecov warning: added line chat_stream.go#L152 was not covered by tests.
}
28 changes: 28 additions & 0 deletions chat_stream_test.go
@@ -767,6 +767,34 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
}
}

type mockStream struct {
calls int
}

// Implement ChatStreamReader.
func (m *mockStream) Recv() (openai.ChatCompletionStreamResponse, error) {
m.calls++
if m.calls == 1 {
return openai.ChatCompletionStreamResponse{ID: "mock1"}, nil
}
return openai.ChatCompletionStreamResponse{}, io.EOF
}
func (m *mockStream) Close() error { return nil }

func TestChatCompletionStream_MockInjection(t *testing.T) {
mock := &mockStream{}
stream := openai.NewChatCompletionStream(mock)

resp, err := stream.Recv()
if err != nil || resp.ID != "mock1" {
t.Errorf("expected mock1, got %v, err %v", resp.ID, err)
}
_, err = stream.Recv()
if !errors.Is(err, io.EOF) {
t.Errorf("expected EOF, got %v", err)
}
}

// Helper funcs.
func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
199 changes: 199 additions & 0 deletions mock_streaming_demo_test.go
@@ -0,0 +1,199 @@
package openai_test

import (
"context"
"errors"
"io"
"testing"

"github.com/sashabaranov/go-openai"
)

// This file demonstrates how to build mock clients for go-openai streaming.
// The pattern is useful when code under test depends on go-openai streaming
// and you need to control the responses it receives.

// MockOpenAIStreamClient demonstrates how to create a full mock client for go-openai.
type MockOpenAIStreamClient struct {
// Configure canned responses
ChatCompletionResponse openai.ChatCompletionResponse
ChatCompletionStreamErr error

// Allow function overrides for more complex scenarios
CreateChatCompletionStreamFn func(
ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionStream, error)
}

func (m *MockOpenAIStreamClient) CreateChatCompletionStream(
ctx context.Context,
req openai.ChatCompletionRequest,
) (*openai.ChatCompletionStream, error) {
if m.CreateChatCompletionStreamFn != nil {
return m.CreateChatCompletionStreamFn(ctx, req)
}
return nil, m.ChatCompletionStreamErr
}

// mockStreamReader returns a scripted sequence of responses, then io.EOF.
type mockStreamReader struct {
responses []openai.ChatCompletionStreamResponse
index int
}

func (m *mockStreamReader) Recv() (openai.ChatCompletionStreamResponse, error) {
if m.index >= len(m.responses) {
return openai.ChatCompletionStreamResponse{}, io.EOF
}
resp := m.responses[m.index]
m.index++
return resp, nil
}

func (m *mockStreamReader) Close() error {
return nil
}

func TestMockOpenAIStreamClient_Demo(t *testing.T) {
// Create expected responses that our mock stream will return
expectedResponses := []openai.ChatCompletionStreamResponse{
{
ID: "test-1",
Object: "chat.completion.chunk",
Model: "gpt-3.5-turbo",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Role: "assistant",
Content: "Hello",
},
},
},
},
{
ID: "test-2",
Object: "chat.completion.chunk",
Model: "gpt-3.5-turbo",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: " World",
},
},
},
},
{
ID: "test-3",
Object: "chat.completion.chunk",
Model: "gpt-3.5-turbo",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{},
FinishReason: "stop",
},
},
},
}

// Create mock client with custom stream function
mockClient := &MockOpenAIStreamClient{
CreateChatCompletionStreamFn: func(
_ context.Context, _ openai.ChatCompletionRequest,
) (*openai.ChatCompletionStream, error) {
// Create a mock stream reader with our expected responses
mockStreamReader := &mockStreamReader{
responses: expectedResponses,
index: 0,
}
// Return a new ChatCompletionStream with our mock reader
return openai.NewChatCompletionStream(mockStreamReader), nil
},
}

// Test the mock client
stream, err := mockClient.CreateChatCompletionStream(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
},
)
if err != nil {
t.Fatalf("CreateChatCompletionStream returned error: %v", err)
}
defer stream.Close()

// Verify we get back exactly the responses we configured
fullResponse := ""
for i, expectedResponse := range expectedResponses {
receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Fatalf("stream.Recv() failed at index %d: %v", i, streamErr)
}

// Additional specific checks
if receivedResponse.ID != expectedResponse.ID {
t.Errorf("Response %d ID mismatch. Expected: %s, Got: %s",
i, expectedResponse.ID, receivedResponse.ID)
}
if len(receivedResponse.Choices) > 0 && len(expectedResponse.Choices) > 0 {
expectedContent := expectedResponse.Choices[0].Delta.Content
receivedContent := receivedResponse.Choices[0].Delta.Content
if receivedContent != expectedContent {
t.Errorf("Response %d content mismatch. Expected: %s, Got: %s",
i, expectedContent, receivedContent)
}
fullResponse += receivedContent
}
}

// Verify EOF at the end
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("Expected EOF at end of stream, got: %v", streamErr)
}

// Verify the full assembled response
expectedFullResponse := "Hello World"
if fullResponse != expectedFullResponse {
t.Errorf("Full response mismatch. Expected: %s, Got: %s", expectedFullResponse, fullResponse)
}

t.Log("✅ Successfully demonstrated mock OpenAI client with streaming responses!")
t.Logf(" Full response assembled: %q", fullResponse)
}

// TestMockOpenAIStreamClient_ErrorHandling demonstrates error handling.
func TestMockOpenAIStreamClient_ErrorHandling(t *testing.T) {
expectedError := errors.New("mock stream error")

mockClient := &MockOpenAIStreamClient{
ChatCompletionStreamErr: expectedError,
}

_, err := mockClient.CreateChatCompletionStream(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
},
)

if !errors.Is(err, expectedError) {
t.Errorf("Expected error %v, got %v", expectedError, err)
}

t.Log("✅ Successfully demonstrated mock OpenAI client error handling!")
}
2 changes: 2 additions & 0 deletions stream_reader.go
@@ -16,6 +16,8 @@ var (
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
)

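// Compile-time check that the concrete streamReader satisfies the new
// ChatStreamReader interface.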
var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)

type streamable interface {
ChatCompletionStreamResponse | CompletionResponse
}