Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vitessdriver: add support for DistributedTx #9451

Merged
merged 7 commits into from
Jan 10, 2022
Prev Previous commit
Next Next commit
vitessdriver: serialization & special-cased query
Signed-off-by: Derek Perkins <derek@nozzle.io>
  • Loading branch information
derekperkins committed Dec 30, 2021
commit 65ccf9fbb8d075184549868f8727b3aecdc85e89
75 changes: 73 additions & 2 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"errors"

"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"

"vitess.io/vitess/go/vt/vtgate/grpcvtgateconn"
"vitess.io/vitess/go/vt/vtgate/vtgateconn"
Expand Down Expand Up @@ -238,8 +242,27 @@ func (c *conn) Close() error {
return nil
}

// GetDistributedTx allows users to send sessions over the wire and reconnect to an existing transaction
func GetDistributedTx(ctx context.Context, session *vtgateconn.VTGateSession) (*sql.Tx, error) {
// SessionTokenFromTx serializes the session on the tx, which can be reconstituted
// into a *sql.Tx using DistributedTxFromSessionToken
func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) {
var sessionToken string

err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken)
if err != nil {
return "", err
}

return sessionToken, nil
}

// DistributedTxFromSessionToken allows users to send serialized sessions over the wire and
// reconnect to an existing transaction
func DistributedTxFromSessionToken(ctx context.Context, sessionToken string) (*sql.Tx, error) {
session, err := sessionTokenToSession(sessionToken)
if err != nil {
return nil, err
}

db, err := OpenWithConfiguration(Configuration{
// include session here - there will be a new *DB created each time
// that stores the session state in the &conn{} struct
Expand All @@ -258,6 +281,49 @@ func GetDistributedTx(ctx context.Context, session *vtgateconn.VTGateSession) (*
return c.BeginTx(ctx, nil)
}

func newSessionTokenRow(session *vtgateconn.VTGateSession, c *converter) (driver.Rows, error) {
sessionToken, err := sessionToSessionToken(session)
if err != nil {
return nil, err
}

qr := sqltypes.Result{
Fields: []*querypb.Field{{
Name: "vt_session_token",
Type: sqltypes.VarBinary,
}},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary(sessionToken),
}},
}

return newRows(&qr, c), nil
}

func sessionToSessionToken(session *vtgateconn.VTGateSession) (string, error) {
b, err := proto.Marshal(session)
if err != nil {
return "", err
}

return base64.StdEncoding.EncodeToString(b), nil
}

func sessionTokenToSession(sessionToken string) (*vtgateconn.VTGateSession, error) {
b, err := base64.StdEncoding.DecodeString(sessionToken)
if err != nil {
return nil, err
}

var session *vtgateconn.VTGateSession
err = proto.Unmarshal(b, session)
if err != nil {
return nil, err
}

return session, nil
}

func (c *conn) Begin() (driver.Tx, error) {
if _, err := c.Exec("begin", nil); err != nil {
return nil, err
Expand Down Expand Up @@ -341,6 +407,11 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// special case for serializing the current session state
if query == "vt_session_token" {
return newSessionTokenRow(c.session, c.convert)
}

bv, err := c.convert.bindVarsFromNamedValues(args)
if err != nil {
return nil, err
Expand Down