Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ func (c *conn) execStagingOperation(
var err error

var isStagingOperation bool
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil && exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation != nil {
isStagingOperation = *exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil {
isStagingOperation = exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation != nil && *exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation
} else {
req := cli_service.TGetResultSetMetadataReq{
OperationHandle: exStmtResp.OperationHandle,
Expand All @@ -581,7 +581,7 @@ func (c *conn) execStagingOperation(
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error performing staging operation", err)
}
isStagingOperation = *resp.IsStagingOperation
isStagingOperation = resp.IsStagingOperation != nil && *resp.IsStagingOperation
}

if !isStagingOperation {
Expand Down
94 changes: 94 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,100 @@ func TestConn_PrepareContext(t *testing.T) {
})
}

func TestConn_execStagingOperation(t *testing.T) {
t.Run("handles nil IsStagingOperation from DirectResults", func(t *testing.T) {
testClient := &client.TestClient{}
testConn := &conn{
session: getTestSession(),
client: testClient,
cfg: config.WithDefaults(),
}

// Create response with nil IsStagingOperation in DirectResults
exStmtResp := &cli_service.TExecuteStatementResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationHandle: &cli_service.TOperationHandle{
OperationId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 223, 34, 54},
Secret: []byte("b"),
},
},
DirectResults: &cli_service.TSparkDirectResults{
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
// IsStagingOperation is nil
},
},
}

// Mock GetResultSetMetadata to return false for IsStagingOperation
var getResultSetMetadataCount int
getResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
getResultSetMetadataCount++
var falseVal = false
return &cli_service.TGetResultSetMetadataResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
IsStagingOperation: &falseVal,
}, nil
}
testClient.FnGetResultSetMetadata = getResultSetMetadata

ctx := context.Background()
err := testConn.execStagingOperation(exStmtResp, ctx)

assert.Nil(t, err)
assert.Equal(t, 0, getResultSetMetadataCount) // should not be called since DirectResults.ResultSetMetadata exists
})

t.Run("handles nil IsStagingOperation from GetResultSetMetadata", func(t *testing.T) {
testClient := &client.TestClient{}
testConn := &conn{
session: getTestSession(),
client: testClient,
cfg: config.WithDefaults(),
}

// Create response with nil DirectResults.ResultSetMetadata
exStmtResp := &cli_service.TExecuteStatementResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationHandle: &cli_service.TOperationHandle{
OperationId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 223, 34, 54},
Secret: []byte("b"),
},
},
// DirectResults.ResultSetMetadata is nil
}

// Mock GetResultSetMetadata to return nil for IsStagingOperation
var getResultSetMetadataCount int
getResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
getResultSetMetadataCount++
return &cli_service.TGetResultSetMetadataResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
// IsStagingOperation is nil
}, nil
}
testClient.FnGetResultSetMetadata = getResultSetMetadata

ctx := context.Background()
err := testConn.execStagingOperation(exStmtResp, ctx)

assert.Nil(t, err)
assert.Equal(t, 1, getResultSetMetadataCount) // should be called since DirectResults.ResultSetMetadata is nil
})
}

func getTestSession() *cli_service.TOpenSessionResp {
return &cli_service.TOpenSessionResp{SessionHandle: &cli_service.TSessionHandle{
SessionId: &cli_service.THandleIdentifier{
Expand Down
Loading