diff --git a/.golangci.yml b/.golangci.yml index 8d77830b1..4dfb106cb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,2 +1,10 @@ run: - tests: true \ No newline at end of file + tests: true + +issues: + exclude-rules: + # we own this driver, so we need to test these functions + - path: statement_test.go + text: "Drivers should implement StmtExecContext instead" + - path: statement_test.go + text: "Drivers should implement StmtQueryContext instead" \ No newline at end of file diff --git a/aaa_test.go b/aaa_test.go index 2f633764f..60eaf6d83 100644 --- a/aaa_test.go +++ b/aaa_test.go @@ -7,15 +7,13 @@ import ( func TestShowServerVersion(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery("SELECT CURRENT_VERSION()") - defer func(rows *RowsExtended) { - err := rows.Close() - assertNilF(t, err) - }(rows) + defer func() { + assertNilF(t, rows.Close()) + }() var version string rows.Next() - err := rows.Scan(&version) - assertNilF(t, err) + assertNilF(t, rows.Scan(&version)) println(version) }) } diff --git a/authexternalbrowser.go b/authexternalbrowser.go index b7fd63bf0..d973ec0c9 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -304,25 +304,23 @@ func doAuthenticateByExternalBrowser( if n < bufSize { // We successfully read all data s := string(buf.Bytes()[:total]) - encodedSamlResponse, err = getTokenFromResponse(s) - if err != nil { - errChan <- err - } + encodedSamlResponse, errAccept = getTokenFromResponse(s) break } buf.Grow(bufSize) } if encodedSamlResponse != "" { httpResponse, err := buildResponse(application) - if err != nil { - errChan <- err - return + if err != nil && errAccept == nil { + errAccept = err } - _, err = c.Write(httpResponse.Bytes()) - errChan <- err - return + if _, err = c.Write(httpResponse.Bytes()); err != nil { + errAccept = err + } + } + if err := c.Close(); err != nil { + logger.Warnf("error while closing browser connection. 
%v", err) } - err = c.Close() encodedSamlResponseChan <- encodedSamlResponse errChan <- errAccept }(conn) diff --git a/bindings_test.go b/bindings_test.go index d49c2f4f5..91530dc5e 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -69,7 +69,9 @@ func TestBindingFloat64(t *testing.T) { dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE test (id int, value %v)", v)) dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if expected != out { @@ -200,12 +202,16 @@ func TestBindingTimestampTZ(t *testing.T) { if err != nil { dbt.Fatal(err.Error()) } - defer assertNilF(t, stmt.Close()) + defer func() { + assertNilF(t, stmt.Close()) + }() if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil { dbt.Fatal(err) } rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v time.Time if rows.Next() { assertNilF(t, rows.Scan(&v)) @@ -251,7 +257,9 @@ func TestBindingTimePtrInStruct(t *testing.T) { } rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v time.Time if rows.Next() { assertNilF(t, rows.Scan(&v)) @@ -298,7 +306,9 @@ func TestBindingTimeInStruct(t *testing.T) { } rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v time.Time if rows.Next() { assertNilF(t, rows.Scan(&v)) @@ -318,7 +328,9 @@ func TestBindingInterface(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if !rows.Next() { dbt.Error("failed to query") } @@ -344,7 +356,9 @@ func TestBindingInterface(t *testing.T) { func TestBindingInterfaceString(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery(selectVariousTypes) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if !rows.Next() { dbt.Error("failed to query") } @@ -381,7 +395,9 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { Array(&nilArray, TimestampTZType), Array(&nilArray, DateType), Array(&nilArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v0 sql.NullInt32 var v1 sql.NullFloat64 @@ -464,7 +480,9 @@ func TestBulkArrayBindingInterface(t *testing.T) { dbt.mustExec(insertSQLBulkArray, Array(&intArray), Array(&fltArray), Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v0 sql.NullInt32 var v1 sql.NullFloat64 @@ -567,7 +585,9 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { Array(&tmArray, TimeType)) rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v0, v1, v2, v3, v4 sql.NullTime @@ -674,7 +694,9 @@ func testBindingArray(t *testing.T, bulk bool) { Array(&tzArray, TimestampTZType), Array(&dtArray, DateType), 
Array(&tmArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var v0 int var v1 float64 @@ -754,7 +776,9 @@ func TestBulkArrayBinding(t *testing.T) { } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname), Array(&intArr), Array(&strArr), Array(&ltzArr, TimestampLTZType), Array(&tzArr, TimestampTZType), Array(&ntzArr, TimestampNTZType), Array(&dateArr, DateType), Array(&timeArr, TimeType), Array(&binArr)) rows := dbt.mustQuery("select * from " + dbname + " order by c1") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() cnt := 0 var i int var s string @@ -800,7 +824,9 @@ func TestBulkArrayBindingTimeWithPrecision(t *testing.T) { } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), Array(&secondsArr, TimeType), Array(&millisecondsArr, TimeType), Array(&microsecondsArr, TimeType), Array(&nanosecondsArr, TimeType)) rows := dbt.mustQuery("select * from " + dbname) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() cnt := 0 var s, ms, us, ns time.Time for rows.Next() { @@ -839,7 +865,9 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName), Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -849,7 +877,9 @@ } rows := dbt.mustQuery("select count(*) from " + tempTableName) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -878,7 +908,9 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { } rows := dbt.mustQuery("select * from binding_test order by c1") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() cnt := startNum var i int for rows.Next() { @@ -926,7 +958,9 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { } rows := dbt.mustQuery("select * from binding_test order by c1,c2") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() cnt := startNum var i sql.NullInt32 var s sql.NullString @@ -1007,7 +1041,9 @@ func TestFunctionParameters(t *testing.T) { if err != nil { t.Fatal(err) } - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Err() != nil { t.Fatal(err) } @@ -1107,7 +1143,9 @@ func TestVariousBindingModes(t *testing.T) { if err != nil { t.Fatal(err) } - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if !rows.Next() { t.Fatal("Expected to return a row") } @@ -1155,7 +1193,9 @@ func testLOBRetrieval(t *testing.T, useArrowFormat bool) { t.Run(fmt.Sprintf("testLOB_%v_useArrowFormat=%v", strconv.Itoa(testSize), strconv.FormatBool(useArrowFormat)), func(t *testing.T) { rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize)) assertNilF(t, err) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize)) // retrieve the result @@ -1186,7 +1226,9 @@ func TestMaxLobSize(t *testing.T) { dbt.mustExec(enableLargeVarcharAndBinary) rows, err := dbt.query("select randstr(20000000, 
random())") assertNilF(t, err) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() }) }) } @@ -1265,7 +1307,9 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) { } rows, err := dbt.query("SELECT * FROM lob_test_table") assertNilF(t, err) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc)) err = rows.Scan(&c1, &c2, &c3) diff --git a/chunk_downloader_test.go b/chunk_downloader_test.go index 487cfe76f..5c5ec3abe 100644 --- a/chunk_downloader_test.go +++ b/chunk_downloader_test.go @@ -99,9 +99,11 @@ func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLIn return err }) assertNilF(t, err) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() values := make([]driver.Value, 1) - assertNilF(t, rows.Next(values)) + assertNotNilE(t, rows.Next(values)) // we deliberately check that there is an error, because we are in arrow batches mode assertEqualE(t, values[0], nil) }) }) @@ -128,7 +130,9 @@ func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormat(t *testing.T) { dbt.mustExec(forceJSON) } rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() assertFalseF(t, rows.Next()) }) }) diff --git a/connection.go b/connection.go index 9eb288a47..b66f3e608 100644 --- a/connection.go +++ b/connection.go @@ -278,7 +278,7 @@ func (sc *snowflakeConn) cleanup() { func (sc *snowflakeConn) Close() (err error) { logger.WithContext(sc.ctx).Infoln("Close") if err := sc.telemetry.sendBatch(); err != nil { - logger.WithContext(sc.ctx).Error(err) + logger.WithContext(sc.ctx).Errorf("error while sending telemetry. %v", err) } sc.stopHeartBeat() defer sc.cleanup() diff --git a/connection_util.go b/connection_util.go index 170c42351..3c216785a 100644 --- a/connection_util.go +++ b/connection_util.go @@ -133,6 +133,9 @@ func (sc *snowflakeConn) processFileTransfer( func getFileStream(ctx context.Context) (*bytes.Buffer, error) { s := ctx.Value(fileStreamFile) + if s == nil { + return nil, nil + } r, ok := s.(io.Reader) if !ok { return nil, errors.New("incorrect io.Reader") diff --git a/converter_test.go b/converter_test.go index 27e4d5cf9..30510280a 100644 --- a/converter_test.go +++ b/converter_test.go @@ -2174,7 +2174,9 @@ func TestSmallTimestampBinding(t *testing.T) { } rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() scanValues := make([]driver.Value, 1) for { @@ -2210,7 +2212,9 @@ func TestTimestampConversionWithoutArrowBatches(t *testing.T) { t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) { query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) rows := sct.mustQueryContext(ctx, query, nil) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { var act time.Time @@ -2290,7 +2294,9 @@ func TestTimestampConversionWithArrowBatchesMicrosecondPassesForDistantDates(t * if err != nil { t.Fatalf("failed to query: %v", err) } - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() // getting result batches batches, err := rows.(*snowflakeRows).GetArrowBatches() @@ -2349,7 +2355,9 @@ func TestTimestampConversionWithArrowBatchesAndWithOriginalTimestamp(t *testing. 
query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() // getting result batches batches, err := rows.(*snowflakeRows).GetArrowBatches() diff --git a/driver_test.go b/driver_test.go index 48b1ceb75..f7ee5f0b9 100644 --- a/driver_test.go +++ b/driver_test.go @@ -629,7 +629,9 @@ func TestCRUD(t *testing.T) { // Read rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func(rows *RowsExtended) { + assertNilF(t, rows.Close()) + }(rows) if rows.Next() { assertNilF(t, rows.Scan(&out)) if !out { @@ -654,7 +656,9 @@ func TestCRUD(t *testing.T) { // Check Update rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func(rows *RowsExtended) { + assertNilF(t, rows.Close()) + }(rows) if rows.Next() { assertNilF(t, rows.Scan(&out)) if out { @@ -709,7 +713,9 @@ func testInt(t *testing.T, json bool) { dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if in != out { @@ -743,7 +749,9 @@ func testFloat32(t *testing.T, json bool) { dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { err := rows.Scan(&out) if err != nil { @@ -779,7 +787,9 @@ func testFloat64(t *testing.T, json bool) { dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") dbt.mustExec("INSERT INTO test VALUES (42.23)") rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if expected != out { @@ -814,7 +824,9 @@ func testString(t *testing.T, json bool) { dbt.mustExec("INSERT INTO test VALUES (?)", in) rows = dbt.mustQuery("SELECT value FROM test") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { assertNilF(t, rows.Scan(&out)) if in != out { @@ -1282,11 +1294,13 @@ func testNULL(t *testing.T, json bool) { dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)") dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) - var out interface{} + var dummy1, out, dummy2 interface{} rows := dbt.mustQuery("SELECT * FROM test") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if rows.Next() { - assertNilF(t, rows.Scan(&out)) + assertNilF(t, rows.Scan(&dummy1, &out, &dummy2)) if out != nil { dbt.Errorf("%v != nil", out) } diff --git a/encrypt_util.go b/encrypt_util.go index d5f6012d7..d4996535b 100644 --- a/encrypt_util.go +++ b/encrypt_util.go @@ -154,7 +154,7 @@ func encryptFileCBC( filename string, chunkSize int, tmpDir string) ( - *encryptMetadata, string, error) { + meta *encryptMetadata, fileName string, err error) { if chunkSize == 0 { chunkSize = aes.BlockSize * 4 * 1024 } @@ -177,7 +177,7 @@ func encryptFileCBC( } }() - meta, err := encryptStreamCBC(sfe, infile, tmpOutputFile, chunkSize) + meta, err = encryptStreamCBC(sfe, infile, tmpOutputFile, chunkSize) if err != nil { return nil, "", err } @@ -228,7 
+228,7 @@ func decryptFileCBC( sfe *snowflakeFileEncryption, filename string, chunkSize int, - tmpDir string) (string, error) { + tmpDir string) (outputFileName string, err error) { tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return "", err } @@ -330,7 +330,7 @@ func encryptFileGCM( sfe *snowflakeFileEncryption, filename string, tmpDir string) ( - *gcmEncryptMetadata, string, error) { + meta *gcmEncryptMetadata, outputFileName string, err error) { tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#") if err != nil { return nil, "", err } @@ -386,7 +386,7 @@ func encryptFileGCM( if err != nil { return nil, "", err } - meta := &gcmEncryptMetadata{ + meta = &gcmEncryptMetadata{ key: base64.StdEncoding.EncodeToString(encryptedFileKey), keyIv: base64.StdEncoding.EncodeToString(keyIv), dataIv: base64.StdEncoding.EncodeToString(dataIv), diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 9592ea205..aa40282c7 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -35,7 +35,6 @@ type ( const ( fileProtocol = "file://" dataSizeThreshold int64 = 64 * 1024 * 1024 - injectWaitPut = 0 isWindows = runtime.GOOS == "windows" mb float64 = 1024.0 * 1024.0 ) @@ -835,15 +834,14 @@ func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMe } continue } else if res.resStatus == renewPresignedURL { - sfa.updateFileMetadataWithPresignedURL() + if err = sfa.updateFileMetadataWithPresignedURL(); err != nil { + return err + } continue } sfa.results = append(sfa.results, res) idx++ - if injectWaitPut > 0 { - time.Sleep(injectWaitPut) - } } return nil } @@ -950,7 +948,9 @@ func (sfa *snowflakeFileTransferAgent) downloadFilesParallel(fileMetas []*fileMe for _, result := range retryMeta { if result.resStatus == renewPresignedURL { - sfa.updateFileMetadataWithPresignedURL() + if err = sfa.updateFileMetadataWithPresignedURL(); err != nil { + return err + } break } } @@ -1208,7 +1208,10 @@ func (spp *snowflakeProgressPercentage) updateProgress(filename string, startTim if status != "" { block := int(math.Round(float64(barLength) * progress)) text := fmt.Sprintf("\r%v(%.2fMB): [%v] %.2f%% %v ", filename, totalSize, strings.Repeat("#", block)+strings.Repeat("-", barLength-block), progress*100, status) - (*outputStream).Write([]byte(text)) + _, err := (*outputStream).Write([]byte(text)) + if err != nil { + logger.Warnf("cannot write status of progress. 
%v", err) + } } return progress == 1.0 } diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index b4b0aa2c8..1540a15c3 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -185,7 +185,9 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() testData := filepath.Join(tmpDir, "data.txt") f, err := os.Create(testData) if err != nil { @@ -774,7 +776,9 @@ func TestCustomTmpDirPath(t *testing.T) { if err != nil { t.Fatalf("cannot create temp directory: %v", err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { @@ -849,7 +853,9 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) { if err != nil { t.Fatalf("cannot create temp directory: %v", err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) @@ -860,11 +866,13 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) { assertNilF(t, err) assertNilF(t, f.Close()) - err = os.Chmod(tmpDir, 0400) + err = os.Chmod(tmpDir, 0500) if err != nil { t.Fatalf("cannot mark directory as readonly: %v", err) } - defer assertNilF(t, os.Chmod(tmpDir, 0600)) + defer func() { + assertNilF(t, os.Chmod(tmpDir, 0700)) + }() uploadMeta := &fileMetadata{ name: "data.txt.gz", @@ -912,7 +920,9 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) { if err != nil { t.Fatalf("cannot create temp directory: %v", err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() uploadFile := filepath.Join(tmpDir, "data.txt") f, err := os.Create(uploadFile) if err != nil { @@ -990,7 +1000,9 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) { if err != nil { t.Fatal(err) } - defer assertNilF(t, os.Remove("download.txt")) + defer func() { + assertNilF(t, os.Remove("download.txt")) + }() if downloadMeta.resStatus != downloaded { t.Fatalf("failed to download file") } diff --git a/file_util.go b/file_util.go index 890242e9e..76402b355 100644 --- a/file_util.go +++ b/file_util.go @@ -40,9 +40,9 @@ func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes. 
return &c, c.Len(), nil } -func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir string) (string, int64, error) { +func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir string) (gzipFileName string, size int64, err error) { basename := baseName(fileName) - gzipFileName := filepath.Join(tmpDir, basename+"_c.gz") + gzipFileName = filepath.Join(tmpDir, basename+"_c.gz") fr, err := os.Open(fileName) if err != nil { @@ -92,7 +92,7 @@ func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) return base64.StdEncoding.EncodeToString(m.Sum(nil)), int64((*stream).Len()), nil } -func (util *snowflakeFileUtil) getDigestAndSizeForFile(fileName string) (string, int64, error) { +func (util *snowflakeFileUtil) getDigestAndSizeForFile(fileName string) (digest string, size int64, err error) { f, err := os.Open(fileName) if err != nil { return "", 0, err diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 78508b39a..0627f6122 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -225,13 +225,17 @@ func (util *snowflakeGcsClient) uploadFile( return err } if resp.StatusCode != http.StatusOK { - meta.lastError = fmt.Errorf("%v", resp.Status) if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = needRetry } else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = renewPresignedURL } else if accessToken != "" && util.isTokenExpired(resp) { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = renewToken + } else { + meta.lastError = fmt.Errorf("%v", resp.Status) } return meta.lastError } @@ -299,15 +303,20 @@ func (util *snowflakeGcsClient) nativeDownloadFile( return err } if resp.StatusCode != http.StatusOK { - meta.lastError = fmt.Errorf("%v", resp.Status) if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = needRetry } else if resp.StatusCode == 404 { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = notFoundFile } else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = renewPresignedURL } else if accessToken != "" && util.isTokenExpired(resp) { + meta.lastError = fmt.Errorf("%v", resp.Status) meta.resStatus = renewToken + } else { + meta.lastError = fmt.Errorf("%v", resp.Status) } return meta.lastError } diff --git a/htap_test.go b/htap_test.go index fed7c6095..8a8f82ccd 100644 --- a/htap_test.go +++ b/htap_test.go @@ -347,7 +347,9 @@ func TestHybridTablesE2E(t *testing.T) { testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID) runSnowflakeConnTest(t, func(sct *SCTest) { dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil) - defer assertNilF(t, dbQuery.Close()) + defer func() { + assertNilF(t, dbQuery.Close()) + }() currentDb := make([]driver.Value, 1) assertNilF(t, dbQuery.Next(currentDb)) defer func() { @@ -362,7 +364,9 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil) rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() row := make([]driver.Value, 2) assertNilF(t, 
rows.Next(row)) if row[0] != "1" || row[1] != "a" { @@ -371,7 +375,9 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil) rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() assertNilF(t, rows2.Next(row)) if row[0] != "1" || row[1] != "a" { t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) @@ -390,7 +396,9 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil) rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() row := make([]driver.Value, 2) assertNilF(t, rows.Next(row)) if row[0] != "3" || row[1] != "c" { @@ -406,7 +414,9 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil) rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() if len(sct.sc.queryContextCache.entries) != 3 { t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) } @@ -567,7 +577,9 @@ func TestConnIsCleanAfterClose(t *testing.T) { var dbName2 string rows2 := dbt2.mustQuery("SELECT CURRENT_DATABASE()") - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() rows2.Next() assertNilF(t, rows2.Scan(&dbName2)) diff --git a/local_storage_client_test.go b/local_storage_client_test.go index 3fc2ab0fb..ac12331ba 100644 --- a/local_storage_client_test.go +++ b/local_storage_client_test.go @@ -17,7 +17,9 @@ func TestLocalUpload(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" @@ -111,7 +113,9 @@ func TestDownloadLocalFile(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" diff --git a/put_get_test.go b/put_get_test.go index 1f8607477..4aed419d2 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -32,11 +32,15 @@ func TestPutError(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() _, err = f.WriteString("test1") assertNilF(t, err) assertNilF(t, os.Chmod(file1, 0000)) - defer assertNilF(t, os.Chmod(file1, 0644)) + defer func() { + assertNilF(t, os.Chmod(file1, 0644)) + }() data := &execResponseData{ Command: string(uploadCommand), @@ -212,7 +216,9 @@ func TestPutLocalFile(t *testing.T) { var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string rows := dbt.mustQuery("copy into gotest_putget_t1") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() for rows.Next() { assertNilF(t, rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)) if s1 != "LOADED" { @@ -221,7 +227,9 @@ func TestPutLocalFile(t *testing.T) { } rows2 := dbt.mustQuery("select count(*) from gotest_putget_t1") - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() var i int if rows2.Next() { assertNilF(t, rows2.Scan(&i)) @@ -231,7 +239,9 @@ func TestPutLocalFile(t 
*testing.T) { } rows3 := dbt.mustQuery(`select STATUS from information_schema .load_history where table_name='gotest_putget_t1'`) - defer assertNilF(t, rows3.Close()) + defer func() { + assertNilF(t, rows3.Close()) + }() if rows3.Next() { assertNilF(t, rows3.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)) if s1 != "LOADED" { @@ -252,7 +262,9 @@ func TestPutGetWithAutoCompressFalse(t *testing.T) { _, err = f.WriteString(originalContents) assertNilF(t, err) assertNilF(t, f.Sync()) - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() runDBTest(t, func(dbt *DBTest) { dbt.mustExec("rm @~/test_put_uncompress_file") @@ -263,7 +275,9 @@ func TestPutGetWithAutoCompressFalse(t *testing.T) { dbt.mustExec(sqlText) defer dbt.mustExec("rm @~/test_put_uncompress_file") rows := dbt.mustQuery("ls @~/test_put_uncompress_file") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var file, s1, s2, s3 string if rows.Next() { err = rows.Scan(&file, &s1, &s2, &s3) @@ -279,7 +293,9 @@ func TestPutGetWithAutoCompressFalse(t *testing.T) { sql := fmt.Sprintf("get @~/test_put_uncompress_file/data.txt 'file://%v'", tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQueryContext(ctx, sqlText) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() for rows2.Next() { err = rows2.Scan(&file, &s1, &s2, &s3) assertNilE(t, err) @@ -435,7 +451,9 @@ func testPutGet(t *testing.T, isStream bool) { if err != nil { t.Error(err) } - defer assertNilF(t, fileStream.Close()) + defer func() { + assertNilF(t, fileStream.Close()) + }() var sqlText string var rows *RowsExtended @@ -450,7 +468,9 @@ func testPutGet(t *testing.T, isStream bool) { sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) rows = dbt.mustQuery(sqlText) } - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5, s6, s7 string assertTrueF(t, rows.Next(), "expected new rows") @@ -478,7 +498,9 @@ func testPutGet(t *testing.T, isStream bool) { sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQueryContext(ctx, sqlText) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() for rows2.Next() { if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil { t.Error(err) @@ -501,7 +523,9 @@ func testPutGet(t *testing.T, isStream bool) { if isStream { gz, err := gzip.NewReader(&streamBuf) assertNilE(t, err) - defer assertNilF(t, gz.Close()) + defer func() { + assertNilF(t, gz.Close()) + }() for { c := make([]byte, defaultChunkBufferSize) if n, err := gz.Read(c); err != nil { @@ -522,11 +546,15 @@ func testPutGet(t *testing.T, isStream bool) { fileName := files[0] f, err := os.Open(fileName) assertNilE(t, err) - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() gz, err := gzip.NewReader(f) assertNilE(t, err) - defer assertNilF(t, gz.Close()) + defer func() { + assertNilF(t, gz.Close()) + }() for { c := make([]byte, defaultChunkBufferSize) @@ -553,7 +581,9 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, os.RemoveAll(tmpDir)) + defer func() { + assertNilF(t, os.RemoveAll(tmpDir)) + }() fname := filepath.Join(tmpDir, "test_put_get.txt.gz") originalContents := "123,test1\n456,test2\n" tableName := randomString(5) @@ -588,7 +618,9 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { sqlText 
= fmt.Sprintf( sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) rows = dbt.mustQuery(sqlText) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5, s6, s7 string if rows.Next() { @@ -612,7 +644,9 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQuery(sqlText) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() for rows2.Next() { if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil { t.Error(err) @@ -678,7 +712,9 @@ func TestPutGetLargeFile(t *testing.T) { dbt.mustExec(sqlText) defer dbt.mustExec("rm @~/test_put_largefile") rows := dbt.mustQuery("ls @~/test_put_largefile") - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var file, s1, s2, s3 string if rows.Next() { err = rows.Scan(&file, &s1, &s2, &s3) @@ -696,7 +732,9 @@ func TestPutGetLargeFile(t *testing.T) { sql := fmt.Sprintf("get @%v 'file://%v'", "~/test_put_largefile/largefile.txt.gz", t.TempDir()) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQueryContext(ctx, sqlText) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() for rows2.Next() { err = rows2.Scan(&file, &s1, &s2, &s3) assertNilE(t, err) @@ -712,7 +750,9 @@ func TestPutGetLargeFile(t *testing.T) { var contents string gz, err := gzip.NewReader(&streamBuf) assertNilE(t, err) - defer assertNilF(t, gz.Close()) + defer func() { + assertNilF(t, gz.Close()) + }() for { c := make([]byte, defaultChunkBufferSize) if n, err := gz.Read(c); err != nil { @@ -768,7 +808,9 @@ func TestPutGetMaxLOBSize(t *testing.T) { defer dbt.mustExec("drop table " + tableName) fileStream, err := os.Open(fname) assertNilF(t, err) - defer assertNilF(t, fileStream.Close()) + defer func() { + assertNilF(t, fileStream.Close()) + }() // test PUT command var sqlText string @@ -777,7 +819,9 @@ func TestPutGetMaxLOBSize(t *testing.T) { sqlText = fmt.Sprintf( sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) rows = dbt.mustQuery(sqlText) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5, s6, s7 string assertTrueF(t, rows.Next(), "expected new rows") @@ -800,7 +844,9 @@ func TestPutGetMaxLOBSize(t *testing.T) { sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows2 := dbt.mustQuery(sqlText) - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() for rows2.Next() { err = rows2.Scan(&s0, &s1, &s2, &s3) assertNilE(t, err) @@ -817,11 +863,15 @@ func TestPutGetMaxLOBSize(t *testing.T) { f, err := os.Open(fileName) assertNilE(t, err) - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() gz, err := gzip.NewReader(f) assertNilE(t, err) - defer assertNilF(t, gz.Close()) + defer func() { + assertNilF(t, gz.Close()) + }() var contents string for { c := make([]byte, defaultChunkBufferSize) diff --git a/put_get_user_stage_test.go b/put_get_user_stage_test.go index 8a7c1e09c..7fe927d4e 100644 --- a/put_get_user_stage_test.go +++ b/put_get_user_stage_test.go @@ -82,7 +82,9 @@ func putGetUserStage(t *testing.T, numberOfFiles int, numberOfLines int, isStrea dbt.mustExec(fmt.Sprintf("copy into %v from @%v", dbname, stageName)) rows := dbt.mustQuery("select count(*) from " + dbname) - defer 
assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var cnt string if rows.Next() { assertNilF(t, rows.Scan(&cnt)) @@ -130,7 +132,9 @@ func TestPutLoadFromUserStage(t *testing.T) { rows := dbt.mustQuery(fmt.Sprintf(`copy into gotest_putget_t2 from @%v file_format = (field_delimiter = '|' error_on_column_count_mismatch =false) purge=true`, data.stage)) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5 string var s6, s7, s8, s9 interface{} orders100 := fmt.Sprintf("s3://%v/%v/orders_100.csv.gz", diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index 97430d1e7..226f37a5a 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -70,7 +70,9 @@ func TestLoadS3(t *testing.T) { AWS_SECRET_KEY='%v') file_format=(skip_header=1 null_if=('') field_optionally_enclosed_by='\"')`, data.awsAccessKeyID, data.awsSecretAccessKey)) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string cnt := 0 for rows.Next() { @@ -140,7 +142,9 @@ func TestPutWithInvalidToken(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() uploader := manager.NewUploader(client) if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ Bucket: &s3Loc.bucketName, @@ -270,7 +274,9 @@ func TestPutGetAWSStage(t *testing.T) { sql := "put 'file://%v' @~/%v auto_compress=false" sqlText := fmt.Sprintf(sql, strings.ReplaceAll(fname, "\\", "\\\\"), stageName) rows := dbt.mustQuery(sqlText) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() var s0, s1, s2, s3, s4, s5, s6, s7 string if rows.Next() { @@ -285,7 +291,9 @@ func TestPutGetAWSStage(t *testing.T) { sql = fmt.Sprintf("get @~/%v 'file://%v'", stageName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") rows = dbt.mustQuery(sqlText) - defer assertNilF(t, rows.Close()) + defer func() { + assertNilF(t, rows.Close()) + }() for rows.Next() { if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { t.Error(err) @@ -314,7 +322,9 @@ func TestPutGetAWSStage(t *testing.T) { if err != nil { t.Error(err) } - defer assertNilF(t, f.Close()) + defer func() { + assertNilF(t, f.Close()) + }() gz, err := gzip.NewReader(f) if err != nil { t.Error(err) diff --git a/rows_test.go b/rows_test.go index f9f6cd6dd..d365abdf5 100644 --- a/rows_test.go +++ b/rows_test.go @@ -505,7 +505,9 @@ func TestLocationChangesAfterAlterSession(t *testing.T) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") dbt.mustExec("INSERT INTO location_timestamp_ltz VALUES('2023-08-09 10:00:00')") rows1 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") - defer assertNilF(t, rows1.Close()) + defer func() { + assertNilF(t, rows1.Close()) + }() if !rows1.Next() { t.Fatalf("cannot read a record") } @@ -516,7 +518,9 @@ func TestLocationChangesAfterAlterSession(t *testing.T) { } dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Pacific/Honolulu'") rows2 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") - defer assertNilF(t, rows2.Close()) + defer func() { + assertNilF(t, rows2.Close()) + }() if !rows2.Next() { t.Fatalf("cannot read a record") } diff --git a/statement_test.go b/statement_test.go index 2133b5c06..f60a8b5d2 100644 --- a/statement_test.go +++ b/statement_test.go @@ -264,7 +264,7 @@ func TestAsyncFailQueryId(t *testing.T) { t.Error("should be in progress") } // Wait for the query to 
complete - assertNilF(t, rows.Next(nil)) + assertNotNilE(t, rows.Next(nil)) if rows.(SnowflakeRows).GetStatus() != QueryFailed { t.Error("should have failed") } diff --git a/storage_client.go b/storage_client.go index f1be2b456..60ebdb584 100644 --- a/storage_client.go +++ b/storage_client.go @@ -145,7 +145,8 @@ func (rsu *remoteStorageUtil) uploadOneFileWithRetry(meta *fileMetadata) error { if meta.resStatus == uploaded || meta.resStatus == skipped { for j := 0; j < 10; j++ { status := meta.resStatus - utilClass.getFileHeader(meta, meta.dstFileName) + _, err := utilClass.getFileHeader(meta, meta.dstFileName) + logger.Infof("error while getting file %v header. %v", meta.dstFileName, err) // check file header status and verify upload/skip if meta.resStatus == notFoundFile { time.Sleep(time.Second) // wait 1 second diff --git a/transaction_test.go b/transaction_test.go index 1c4c3897a..8410f3e14 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -16,7 +16,9 @@ func TestTransactionOptions(t *testing.T) { var err error conn := openConn(t) - defer assertNilF(t, conn.Close()) + defer func() { + assertNilF(t, conn.Close()) + }() tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { @@ -45,7 +47,9 @@ func TestTransactionContext(t *testing.T) { var err error conn := openConn(t) - defer assertNilF(t, conn.Close()) + defer func() { + assertNilF(t, conn.Close()) + }() ctx := context.Background() @@ -60,7 +64,6 @@ func TestTransactionContext(t *testing.T) { if err != nil { t.Fatal(err) } - defer assertNilF(t, tx.Rollback()) _, err = tx.ExecContext(ctx, "SELECT SYSTEM$WAIT(10, 'SECONDS')") if err != nil {