diff --git a/cdc/entry/schema_storage_test.go b/cdc/entry/schema_storage_test.go index 6e2f35a414b..a975d14eaf4 100644 --- a/cdc/entry/schema_storage_test.go +++ b/cdc/entry/schema_storage_test.go @@ -1011,25 +1011,24 @@ func TestHandleKey(t *testing.T) { } func TestGetPrimaryKey(t *testing.T) { - t.Parallel() - helper := NewSchemaTestHelper(t) defer helper.Close() - + // PKISHandle is true, primary key is also the handle, since it's integer type. sql := `create table test.t1(a int primary key, b int)` - job := helper.DDL2Job(sql) - tableInfo := model.WrapTableInfo(0, "test", 0, job.BinlogInfo.TableInfo) + event := helper.DDL2Event(sql) - names := tableInfo.GetPrimaryKeyColumnNames() - require.Len(t, names, 1) - require.Containsf(t, names, "a", "names: %v", names) + names := event.TableInfo.GetPrimaryKeyColumnNames() + require.Equal(t, names, []string{"a"}) + // IsCommonHandle is true, primary key is not the handle, since it contains multiple fields. sql = `create table test.t2(a int, b int, c int, primary key(a, b))` - job = helper.DDL2Job(sql) - tableInfo = model.WrapTableInfo(0, "test", 0, job.BinlogInfo.TableInfo) - - names = tableInfo.GetPrimaryKeyColumnNames() - require.Len(t, names, 2) - require.Containsf(t, names, "a", "names: %v", names) - require.Containsf(t, names, "b", "names: %v", names) + event = helper.DDL2Event(sql) + names = event.TableInfo.GetPrimaryKeyColumnNames() + require.Equal(t, names, []string{"a", "b"}) + + // IsCommonHandle is true, primary key is not the handle, since it's not integer type. + sql = `create table test.t3(a varchar(10) primary key, b int)` + event = helper.DDL2Event(sql) + names = event.TableInfo.GetPrimaryKeyColumnNames() + require.Equal(t, names, []string{"a"}) } diff --git a/cdc/model/schema_storage.go b/cdc/model/schema_storage.go index c43c9562c92..17fda710c23 100644 --- a/cdc/model/schema_storage.go +++ b/cdc/model/schema_storage.go @@ -369,22 +369,16 @@ func (ti *TableInfo) OffsetsByNames(names []string) ([]int, bool) { // GetPrimaryKeyColumnNames returns the primary key column names func (ti *TableInfo) GetPrimaryKeyColumnNames() []string { - result := make([]string, 0) - for _, index := range ti.Indices { - if index.Primary { - for _, col := range index.Columns { - result = append(result, col.Name.O) - } - return result - } + var result []string + if ti.PKIsHandle { + result = append(result, ti.GetPkColInfo().Name.O) + return result } - for _, columnsOffsets := range ti.IndexColumnsOffset { - for _, offset := range columnsOffsets { - columnInfo := ti.Columns[offset] - if mysql.HasPriKeyFlag(columnInfo.FieldType.GetFlag()) { - result = append(result, columnInfo.Name.O) - } + indexInfo := ti.GetPrimaryKey() + if indexInfo != nil { + for _, col := range indexInfo.Columns { + result = append(result, col.Name.O) } } return result