Skip to content

Commit

Permalink
Exec() now provides access to status of multiple statements. (#1309)
Browse files Browse the repository at this point in the history
It now reports the last inserted ID and affected row count for all statements,
not just the last one. This is useful to execute batches of statements such as
UPDATE with minimal roundtrips.

Co-authored-by: Inada Naoki <songofacandy@gmail.com>
  • Loading branch information
mherr-google and methane authored May 29, 2023
1 parent f43effa commit 397e2f5
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 50 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,22 @@ Allow multiple statements in one query. This can be used to bach multiple querie

When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.

It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:

```go
conn, _ := db.Conn(ctx)
conn.Raw(func(conn interface{}) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
UPDATE point SET x = 1 WHERE y = 2;
UPDATE point SET x = 2 WHERE y = 3;
`, nil)
// Both slices have 2 elements.
log.Print(res.(mysql.Result).AllRowsAffected())
log.Print(res.(mysql.Result).AllLastInsertIds())
})
```

##### `parseTime`

```
Expand Down
6 changes: 3 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
if err = mc.resultUnchanged().readResultOK(); err == nil {
return nil // auth successful
}

Expand Down Expand Up @@ -397,7 +397,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err
}
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()

default:
return ErrMalformPkt
Expand Down Expand Up @@ -426,7 +426,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
if err != nil {
return err
}
return mc.readResultOK()
return mc.resultUnchanged().readResultOK()
}

default:
Expand Down
29 changes: 15 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import (
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
cfg *Config
connector *connector
maxAllowedPacket int
Expand Down Expand Up @@ -155,6 +154,7 @@ func (mc *mysqlConn) cleanup() {
if err := mc.netConn.Close(); err != nil {
mc.cfg.Logger.Print(err)
}
mc.clearResult()
}

func (mc *mysqlConn) error() error {
Expand Down Expand Up @@ -316,28 +316,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
Expand All @@ -354,14 +351,16 @@ func (mc *mysqlConn) exec(query string) error {
}
}

return mc.discardResults()
return handleOk.discardResults()
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
handleOk := mc.clearResult()

if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -382,7 +381,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
resLen, err = handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -410,12 +409,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
handleOk := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -466,11 +466,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
}
defer mc.finish()

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}

return mc.readResultOK()
return handleOk.readResultOK()
}

// BeginTx implements driver.ConnBeginTx interface
Expand Down
112 changes: 112 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2154,11 +2154,51 @@ func TestRejectReadOnly(t *testing.T) {
}

func TestPing(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.fail("Ping", "Ping", err)
}
})

runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}

// Check that affectedRows and insertIds are cleared after each call.
conn.Raw(func(conn interface{}) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
if err != nil {
dbt.fail("Conn", "Query", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
}
q.Close()

// Verify that Ping() clears both fields.
for i := 0; i < 2; i++ {
if err := c.Ping(ctx); err != nil {
dbt.fail("Pinger", "Ping", err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

// See Issue #799
Expand Down Expand Up @@ -2378,6 +2418,42 @@ func TestMultiResultSetNoSelect(t *testing.T) {
})
}

func TestExecMultipleResults(t *testing.T) {
ctx := context.Background()
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)
conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
conn.Raw(func(conn interface{}) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
INSERT INTO test (value) VALUES ('a'), ('b');
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
`, nil)
if err != nil {
t.Fatalf("insert statements failed: %v", err)
}
mres := res.(Result)
if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
}
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
// first inserted ID, not the last.
if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
}
return nil
})
})
}

// tests if rows are set in a proper state if some results were ignored before
// calling rows.NextResultSet.
func TestSkipResults(t *testing.T) {
Expand All @@ -2399,6 +2475,42 @@ func TestSkipResults(t *testing.T) {
})
}

func TestQueryMultipleResults(t *testing.T) {
ctx := context.Background()
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)
conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
conn.Raw(func(conn interface{}) error {
qr := conn.(driver.Queryer)

c := conn.(*mysqlConn)

// Demonstrate that repeated queries reset the affectedRows
for i := 0; i < 2; i++ {
_, err := qr.Query(`
INSERT INTO test (value) VALUES ('a'), ('b');
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
`, nil)
if err != nil {
t.Fatalf("insert statements failed: %v", err)
}
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

func TestPingContext(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
8 changes: 4 additions & 4 deletions infile.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {

const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := defaultPacketSize
Expand Down Expand Up @@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
Expand All @@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr
}

Expand All @@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
return mc.readResultOK()
}

mc.readPacket()
mc.conn().readPacket()
return err
}
Loading

0 comments on commit 397e2f5

Please sign in to comment.