Skip to content

Commit

Permalink
Fix defers
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Oct 30, 2024
1 parent 1db343c commit b1382de
Show file tree
Hide file tree
Showing 23 changed files with 317 additions and 128 deletions.
10 changes: 9 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
run:
tests: true
tests: true

issues:
exclude-rules:
# we own this driver so we need to test this functions
- path: statement_test.go
text: "Drivers should implement StmtExecContext instead"
- path: statement_test.go
text: "Drivers should implement StmtQueryContext instead"
10 changes: 4 additions & 6 deletions aaa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ import (
func TestShowServerVersion(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQuery("SELECT CURRENT_VERSION()")
defer func(rows *RowsExtended) {
err := rows.Close()
assertNilF(t, err)
}(rows)
defer func() {
assertNilF(t, rows.Close())
}()

var version string
rows.Next()
err := rows.Scan(&version)
assertNilF(t, err)
assertNilF(t, rows.Scan(&version))
println(version)
})
}
20 changes: 9 additions & 11 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,25 +304,23 @@ func doAuthenticateByExternalBrowser(
if n < bufSize {
// We successfully read all data
s := string(buf.Bytes()[:total])
encodedSamlResponse, err = getTokenFromResponse(s)
if err != nil {
errChan <- err
}
encodedSamlResponse, errAccept = getTokenFromResponse(s)
break
}
buf.Grow(bufSize)
}
if encodedSamlResponse != "" {
httpResponse, err := buildResponse(application)
if err != nil {
errChan <- err
return
if err != nil && errAccept == nil {
errAccept = err
}
_, err = c.Write(httpResponse.Bytes())
errChan <- err
return
if _, err = c.Write(httpResponse.Bytes()); err != nil {
errAccept = err
}
}
if err := c.Close(); err != nil {
logger.Warnf("error while closing browser connection. %v", err)
}
err = c.Close()
encodedSamlResponseChan <- encodedSamlResponse
errChan <- errAccept
}(conn)
Expand Down
88 changes: 66 additions & 22 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ func TestBindingFloat64(t *testing.T) {
dbt.mustExec(fmt.Sprintf("CREATE OR REPLACE TABLE test (id int, value %v)", v))
dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected)
rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if rows.Next() {
assertNilF(t, rows.Scan(&out))
if expected != out {
Expand Down Expand Up @@ -200,12 +202,16 @@ func TestBindingTimestampTZ(t *testing.T) {
if err != nil {
dbt.Fatal(err.Error())
}
defer assertNilF(t, stmt.Close())
defer func() {
assertNilF(t, stmt.Close())
}()
if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil {
dbt.Fatal(err)
}
rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
assertNilF(t, rows.Scan(&v))
Expand Down Expand Up @@ -251,7 +257,9 @@ func TestBindingTimePtrInStruct(t *testing.T) {
}

rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
assertNilF(t, rows.Scan(&v))
Expand Down Expand Up @@ -298,7 +306,9 @@ func TestBindingTimeInStruct(t *testing.T) {
}

rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
assertNilF(t, rows.Scan(&v))
Expand All @@ -318,7 +328,9 @@ func TestBindingInterface(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQueryContext(
WithHigherPrecision(context.Background()), selectVariousTypes)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if !rows.Next() {
dbt.Error("failed to query")
}
Expand All @@ -344,7 +356,9 @@ func TestBindingInterface(t *testing.T) {
func TestBindingInterfaceString(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQuery(selectVariousTypes)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if !rows.Next() {
dbt.Error("failed to query")
}
Expand Down Expand Up @@ -381,7 +395,9 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) {
Array(&nilArray, TimestampTZType), Array(&nilArray, DateType),
Array(&nilArray, TimeType))
rows := dbt.mustQuery(selectAllSQL)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

var v0 sql.NullInt32
var v1 sql.NullFloat64
Expand Down Expand Up @@ -464,7 +480,9 @@ func TestBulkArrayBindingInterface(t *testing.T) {
dbt.mustExec(insertSQLBulkArray, Array(&intArray), Array(&fltArray),
Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array))
rows := dbt.mustQuery(selectAllSQLBulkArray)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

var v0 sql.NullInt32
var v1 sql.NullFloat64
Expand Down Expand Up @@ -567,7 +585,9 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) {
Array(&tmArray, TimeType))

rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

var v0, v1, v2, v3, v4 sql.NullTime

Expand Down Expand Up @@ -674,7 +694,9 @@ func testBindingArray(t *testing.T, bulk bool) {
Array(&tzArray, TimestampTZType), Array(&dtArray, DateType),
Array(&tmArray, TimeType))
rows := dbt.mustQuery(selectAllSQL)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

var v0 int
var v1 float64
Expand Down Expand Up @@ -754,7 +776,9 @@ func TestBulkArrayBinding(t *testing.T) {
}
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname), Array(&intArr), Array(&strArr), Array(&ltzArr, TimestampLTZType), Array(&tzArr, TimestampTZType), Array(&ntzArr, TimestampNTZType), Array(&dateArr, DateType), Array(&timeArr, TimeType), Array(&binArr))
rows := dbt.mustQuery("select * from " + dbname + " order by c1")
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
cnt := 0
var i int
var s string
Expand Down Expand Up @@ -800,7 +824,9 @@ func TestBulkArrayBindingTimeWithPrecision(t *testing.T) {
}
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), Array(&secondsArr, TimeType), Array(&millisecondsArr, TimeType), Array(&microsecondsArr, TimeType), Array(&nanosecondsArr, TimeType))
rows := dbt.mustQuery("select * from " + dbname)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
cnt := 0
var s, ms, us, ns time.Time
for rows.Next() {
Expand Down Expand Up @@ -839,7 +865,9 @@ func TestBulkArrayMultiPartBinding(t *testing.T) {
fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName),
Array(&randomStrings))
rows := dbt.mustQuery("select count(*) from " + tempTableName)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if rows.Next() {
var count int
if err := rows.Scan(&count); err != nil {
Expand All @@ -849,7 +877,9 @@ func TestBulkArrayMultiPartBinding(t *testing.T) {
}

rows := dbt.mustQuery("select count(*) from " + tempTableName)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if rows.Next() {
var count int
if err := rows.Scan(&count); err != nil {
Expand Down Expand Up @@ -878,7 +908,9 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) {
}

rows := dbt.mustQuery("select * from binding_test order by c1")
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
cnt := startNum
var i int
for rows.Next() {
Expand Down Expand Up @@ -926,7 +958,9 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) {
}

rows := dbt.mustQuery("select * from binding_test order by c1,c2")
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
cnt := startNum
var i sql.NullInt32
var s sql.NullString
Expand Down Expand Up @@ -1007,7 +1041,9 @@ func TestFunctionParameters(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if rows.Err() != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1107,7 +1143,9 @@ func TestVariousBindingModes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
if !rows.Next() {
t.Fatal("Expected to return a row")
}
Expand Down Expand Up @@ -1155,7 +1193,9 @@ func testLOBRetrieval(t *testing.T, useArrowFormat bool) {
t.Run(fmt.Sprintf("testLOB_%v_useArrowFormat=%v", strconv.Itoa(testSize), strconv.FormatBool(useArrowFormat)), func(t *testing.T) {
rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize))
assertNilF(t, err)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize))

// retrieve the result
Expand Down Expand Up @@ -1186,7 +1226,9 @@ func TestMaxLobSize(t *testing.T) {
dbt.mustExec(enableLargeVarcharAndBinary)
rows, err := dbt.query("select randstr(20000000, random())")
assertNilF(t, err)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
})
})
}
Expand Down Expand Up @@ -1265,7 +1307,9 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) {
}
rows, err := dbt.query("SELECT * FROM lob_test_table")
assertNilF(t, err)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc))

err = rows.Scan(&c1, &c2, &c3)
Expand Down
10 changes: 7 additions & 3 deletions chunk_downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLIn
return err
})
assertNilF(t, err)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
values := make([]driver.Value, 1)
assertNilF(t, rows.Next(values))
assertNotNilE(t, rows.Next(values)) // we deliberately check that there is an error, because we are in arrow batches mode
assertEqualE(t, values[0], nil)
})
})
Expand All @@ -128,7 +130,9 @@ func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormat(t *testing.T) {
dbt.mustExec(forceJSON)
}
rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1")
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()
assertFalseF(t, rows.Next())
})
})
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (sc *snowflakeConn) cleanup() {
func (sc *snowflakeConn) Close() (err error) {
logger.WithContext(sc.ctx).Infoln("Close")
if err := sc.telemetry.sendBatch(); err != nil {
logger.WithContext(sc.ctx).Error(err)
logger.WithContext(sc.ctx).Errorf("error while sending telemetry. %v", err)
}
sc.stopHeartBeat()
defer sc.cleanup()
Expand Down
3 changes: 3 additions & 0 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ func (sc *snowflakeConn) processFileTransfer(

func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
s := ctx.Value(fileStreamFile)
if s == nil {
return nil, nil
}
r, ok := s.(io.Reader)
if !ok {
return nil, errors.New("incorrect io.Reader")
Expand Down
16 changes: 12 additions & 4 deletions converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2174,7 +2174,9 @@ func TestSmallTimestampBinding(t *testing.T) {
}

rows := sct.mustQueryContext(ctx, "SELECT ?", parameters)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

scanValues := make([]driver.Value, 1)
for {
Expand Down Expand Up @@ -2210,7 +2212,9 @@ func TestTimestampConversionWithoutArrowBatches(t *testing.T) {
t.Run(tp+"("+strconv.Itoa(scale)+")_"+tsStr, func(t *testing.T) {
query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale)
rows := sct.mustQueryContext(ctx, query, nil)
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

if rows.Next() {
var act time.Time
Expand Down Expand Up @@ -2290,7 +2294,9 @@ func TestTimestampConversionWithArrowBatchesMicrosecondPassesForDistantDates(t *
if err != nil {
t.Fatalf("failed to query: %v", err)
}
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

// getting result batches
batches, err := rows.(*snowflakeRows).GetArrowBatches()
Expand Down Expand Up @@ -2349,7 +2355,9 @@ func TestTimestampConversionWithArrowBatchesAndWithOriginalTimestamp(t *testing.

query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale)
rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{})
defer assertNilF(t, rows.Close())
defer func() {
assertNilF(t, rows.Close())
}()

// getting result batches
batches, err := rows.(*snowflakeRows).GetArrowBatches()
Expand Down
Loading

0 comments on commit b1382de

Please sign in to comment.