Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transaction watcher and notifier #16363

Merged
merged 10 commits into from
Jul 17, 2024
Prev Previous commit
Next Next commit
test: added transaction watcher tests
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Jul 12, 2024
commit b24906be3dd96164ffdeba063d355a4a9103745c
26 changes: 0 additions & 26 deletions go/vt/vttablet/endtoend/framework/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ import (

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/vtgate/fakerpcvtgateconn"
"vitess.io/vitess/go/vt/vtgate/vtgateconn"
"vitess.io/vitess/go/vt/vttablet/tabletserver"
"vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv"

Expand All @@ -51,8 +49,6 @@ var (
Server *tabletserver.TabletServer
// ServerAddress is the http URL for the server.
ServerAddress string
// ResolveChan is the channel that sends dtids that are to be resolved.
ResolveChan = make(chan string, 1)
// TopoServer is the topology for the server
TopoServer *topo.Server
)
Expand All @@ -61,15 +57,6 @@ var (
// all the global variables. This function should only be called
// once at the beginning of the test.
func StartCustomServer(ctx context.Context, connParams, connAppDebugParams mysql.ConnParams, dbName string, cfg *tabletenv.TabletConfig) error {
// Setup a fake vtgate server.
protocol := "resolveTest"
vtgateconn.SetVTGateProtocol(protocol)
vtgateconn.RegisterDialer(protocol, func(context.Context, string) (vtgateconn.Impl, error) {
return &txResolver{
FakeVTGateConn: fakerpcvtgateconn.FakeVTGateConn{},
}, nil
})

dbcfgs := dbconfigs.NewTestDBConfigs(connParams, connAppDebugParams, dbName)

Target = &querypb.Target{
Expand Down Expand Up @@ -137,16 +124,3 @@ func StartServer(ctx context.Context, connParams, connAppDebugParams mysql.ConnP
func StopServer() {
Server.StopService()
}

// txResolver transmits dtids to be resolved through ResolveChan.
type txResolver struct {
fakerpcvtgateconn.FakeVTGateConn
}

func (conn *txResolver) ResolveTransaction(ctx context.Context, dtid string) error {
select {
case ResolveChan <- dtid:
default:
}
return nil
}
66 changes: 36 additions & 30 deletions go/vt/vttablet/endtoend/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,8 @@ func TestMMRollbackFlow(t *testing.T) {
require.NoError(t, err)
}

func TestTransactionWatcher(t *testing.T) {
t.Skip("TODO: need to update this test")
// TestTransactionWatcherSignal test that unresolved transaction signal is received via health stream.
func TestTransactionWatcherSignal(t *testing.T) {
client := framework.NewClient()

query := "insert into vitess_test (intval, floatval, charval, binval) " +
Expand All @@ -623,44 +623,50 @@ func TestTransactionWatcher(t *testing.T) {
_, err = client.Execute(query, nil)
require.NoError(t, err)

start := time.Now()
err = client.CreateTransaction("aa", []*querypb.Target{{
Keyspace: "test1",
Shard: "0",
}, {
Keyspace: "test2",
Shard: "1",
}})
ch := make(chan any)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := client.StreamHealthWithContext(ctx, func(shr *querypb.StreamHealthResponse) error {
if shr.RealtimeStats.TxUnresolved {
ch <- true
}
return nil
})
require.NoError(t, err)
}()

err = client.CreateTransaction("aa", []*querypb.Target{
{Keyspace: "test1", Shard: "0"},
{Keyspace: "test2", Shard: "1"}})
require.NoError(t, err)

// The watchdog should kick in after 1 second.
dtid := <-framework.ResolveChan
if dtid != "aa" {
t.Errorf("dtid: %s, want aa", dtid)
}
diff := time.Since(start)
if diff < 1*time.Second {
t.Errorf("diff: %v, want greater than 1s", diff)
// wait for unresolved transaction signal
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for transaction watcher signal")
}

err = client.SetRollback("aa", 0)
require.NoError(t, err)

// still should receive unresolved transaction signal
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for transaction watcher signal")
}

err = client.ConcludeTransaction("aa")
require.NoError(t, err)

// Make sure the watchdog stops sending messages.
// Check twice. Sometimes, a race can still cause
// a stray message.
dtid = ""
for i := 0; i < 2; i++ {
select {
case dtid = <-framework.ResolveChan:
continue
case <-time.After(2 * time.Second):
return
}
// transaction watcher should stop sending singal now.
select {
case <-ch:
t.Fatal("unexpected signal for unresolved transaction")
case <-time.After(2 * time.Second):
}
t.Errorf("Unexpected message: %s", dtid)
}

func TestUnresolvedTracking(t *testing.T) {
Expand Down
20 changes: 17 additions & 3 deletions go/vt/vttablet/tabletserver/dt_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ func TestExecutorReadAllTransactions(t *testing.T) {
}
}

func TestExecutorResolveTransaction(t *testing.T) {
// TestTransactionNotifier tests that the transaction notifier is called
// when a transaction watcher receives unresolved transaction count more than zero.
func TestTransactionNotifier(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -457,17 +459,29 @@ func TestExecutorResolveTransaction(t *testing.T) {
db.AddQueryPattern(
"select count\\(\\*\\) from _vt\\.redo_state where time_created.*",
sqltypes.MakeTestResult(sqltypes.MakeTestFields("count(*)", "int64"), "0"))

// zero unresolved transactions
db.AddQueryPattern(
"select count\\(\\*\\) from _vt\\.dt_state where time_created.*",
sqltypes.MakeTestResult(sqltypes.MakeTestFields("count(*)", "int64"), "1"))
sqltypes.MakeTestResult(sqltypes.MakeTestFields("count(*)", "int64"), "0"))
notifyCh := make(chan any)
tsv.te.dxNotify = func() {
notifyCh <- nil
}
select {
case <-notifyCh:
t.Error("unresolved transaction notifier call unexpected")
case <-time.After(1 * time.Second):
}

// non zero unresolved transactions
db.AddQueryPattern(
"select count\\(\\*\\) from _vt\\.dt_state where time_created.*",
sqltypes.MakeTestResult(sqltypes.MakeTestFields("count(*)", "int64"), "1"))
select {
case <-notifyCh:
case <-time.After(1 * time.Second):
t.Error("unresolved transaction notifier not called")
t.Error("unresolved transaction notifier expected but not received")
}
}

Expand Down
Loading