Skip to content

Commit

Permalink
chore(spanner): refactor table replace on column attrs change
Browse files Browse the repository at this point in the history
  • Loading branch information
newtonnthiga committed Aug 15, 2024
1 parent 9beea05 commit 90a23ad
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 77 deletions.
106 changes: 93 additions & 13 deletions internal/spanner/alis_google_spanner_table_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/booldefault"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/boolplanmodifier"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/listplanmodifier"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
Expand Down Expand Up @@ -167,9 +167,6 @@ func (r *spannerTableResource) Schema(_ context.Context, _ resource.SchemaReques
"Multiple columns can be specified as primary keys to create a composite primary key.\n" +
"Primary key columns must be non-null.\n" +
"**Changing this value will cause a table replace**.",
PlanModifiers: []planmodifier.Bool{
boolplanmodifier.RequiresReplace(),
},
},
"is_computed": schema.BoolAttribute{
Optional: true,
Expand All @@ -178,9 +175,6 @@ func (r *spannerTableResource) Schema(_ context.Context, _ resource.SchemaReques
"A common use case is to generate a column from a PROTO column field.\n" +
"This should be accompanied by a `computation_ddl` field.\n" +
"**Changing this value will cause a table replace**.",
PlanModifiers: []planmodifier.Bool{
boolplanmodifier.RequiresReplace(),
},
},
"computation_ddl": schema.StringAttribute{
Optional: true,
Expand All @@ -189,9 +183,6 @@ func (r *spannerTableResource) Schema(_ context.Context, _ resource.SchemaReques
"The expression must be a valid SQL expression that generates a value for the column.\n" +
"Example: `column1 + column2`, or `proto_column.field`.\n" +
"**Changing this value will cause a table replace**.",
PlanModifiers: []planmodifier.String{
stringplanmodifier.RequiresReplace(),
},
},
"auto_increment": schema.BoolAttribute{
Optional: true,
Expand All @@ -210,9 +201,6 @@ func (r *spannerTableResource) Schema(_ context.Context, _ resource.SchemaReques
Description: "The data type of the column.\n" +
"Valid types are: `BOOL`, `INT64`, `FLOAT64`, `STRING`, `BYTES`, `DATE`, `TIMESTAMP`, `JSON`, `PROTO`, `ARRAY<STRING>`, `ARRAY<INT64>`, `ARRAY<FLOAT32>`, `ARRAY<FLOAT64>`.\n" +
"**Changing this value will cause a table replace**.",
PlanModifiers: []planmodifier.String{
stringplanmodifier.RequiresReplace(),
},
},
"size": schema.Int64Attribute{
Optional: true,
Expand Down Expand Up @@ -263,6 +251,98 @@ func (r *spannerTableResource) Schema(_ context.Context, _ resource.SchemaReques
},
},
Description: "The columns of the table.",
PlanModifiers: []planmodifier.List{
listplanmodifier.RequiresReplaceIf(func(ctx context.Context, req planmodifier.ListRequest, resp *listplanmodifier.RequiresReplaceIfFuncResponse) {
// Create a map of the columns by name
type PriorAndCurrentColumns struct {
Prior *spannerTableColumn
Current *spannerTableColumn
}
columnsMap := make(map[string]*PriorAndCurrentColumns)

// Get the columns prior to the plan
priorColumns := make([]spannerTableColumn, 0, len(req.StateValue.Elements()))
d := req.StateValue.ElementsAs(ctx, &priorColumns, false)
if d.HasError() {
resp.Diagnostics.Append(d...)
return
}
for _, column := range priorColumns {
if _, ok := columnsMap[column.Name.ValueString()]; !ok {
columnsMap[column.Name.ValueString()] = &PriorAndCurrentColumns{}
}
columnsMap[column.Name.ValueString()].Prior = &column
}

// Get the columns after the plan
currentColumns := make([]spannerTableColumn, 0, len(req.PlanValue.Elements()))
d = req.PlanValue.ElementsAs(ctx, &currentColumns, false)
if d.HasError() {
resp.Diagnostics.Append(d...)
return
}
for _, column := range currentColumns {
if _, ok := columnsMap[column.Name.ValueString()]; !ok {
columnsMap[column.Name.ValueString()] = &PriorAndCurrentColumns{}
}
columnsMap[column.Name.ValueString()].Current = &column
}

// Check if the columns are the same.
// Columns that are new do not require a replace, unless a primary key is added.
// Columns that are removed do not require a replace, unless they are part of the primary key.
// Columns that are updated require a replace if: the column type is changed,
// the primary key status is changed, or the column's computation_ddl is changed.
for name, columns := range columnsMap {
// Column is new
if columns.Prior == nil && columns.Current != nil {
// Check if the column is a primary key
if !columns.Current.IsPrimaryKey.IsNull() && columns.Current.IsPrimaryKey.ValueBool() {
resp.RequiresReplace = true
resp.Diagnostics.AddWarning(fmt.Sprintf("Column %q requires a table replace", name), fmt.Sprintf("Column %q is a new primary key column and requires a table replace", name))
}
continue
}

// Column is removed
if columns.Current == nil && columns.Prior != nil {
// Check if the column is a primary key
if !columns.Prior.IsPrimaryKey.IsNull() && columns.Prior.IsPrimaryKey.ValueBool() {
resp.RequiresReplace = true
resp.Diagnostics.AddWarning(fmt.Sprintf("Column %q requires a table replace", name), fmt.Sprintf("Column %q is a removed primary key column and requires a table replace", name))
}
continue
}

// Column type is changed
// Type is required, so we can safely assume it is not null
if columns.Prior.Type.ValueString() != columns.Current.Type.ValueString() {
resp.RequiresReplace = true
resp.Diagnostics.AddWarning(fmt.Sprintf("Column %q requires a table replace", name), fmt.Sprintf("Column %q has a changed type and requires a table replace", name))
}

// Column primary key status is changed
// This is not required, so we also need to check if it is null
if (!columns.Prior.IsPrimaryKey.IsNull() && !columns.Current.IsPrimaryKey.IsNull() && columns.Prior.IsPrimaryKey.ValueBool() != columns.Current.IsPrimaryKey.ValueBool()) ||
(columns.Prior.IsPrimaryKey.IsNull() && !columns.Current.IsPrimaryKey.IsNull() && columns.Current.IsPrimaryKey.ValueBool()) ||
(!columns.Prior.IsPrimaryKey.IsNull() && columns.Prior.IsPrimaryKey.ValueBool() && columns.Current.IsPrimaryKey.IsNull()) {
resp.RequiresReplace = true
resp.Diagnostics.AddWarning(fmt.Sprintf("Column %q requires a table replace", name), fmt.Sprintf("Column %q has a changed primary key status and requires a table replace", name))
}

// Column is computed and computation_ddl is changed
// Both fields are required but only if at least one is set
if (!columns.Prior.IsComputed.IsNull() && columns.Prior.IsComputed.ValueBool() && !columns.Current.IsComputed.IsNull() && columns.Current.IsComputed.ValueBool() &&
columns.Prior.ComputationDdl.ValueString() != columns.Current.ComputationDdl.ValueString()) ||
(!columns.Prior.IsComputed.IsNull() && columns.Prior.IsComputed.ValueBool() && (columns.Current.IsComputed.IsNull() || !columns.Current.IsComputed.ValueBool())) {
resp.RequiresReplace = true
resp.Diagnostics.AddWarning(fmt.Sprintf("Column %q requires a table replace", name), fmt.Sprintf("Column %q has a changed computation_ddl or is_computed has been disabled and requires a table replace", name))
}
}

},
"If certain values of any of the columns change, Terraform will destroy and recreate the table.", "If certain values of any of the columns change, Terraform will destroy and recreate the table."),
},
},
},
Description: "The schema of the table.",
Expand Down
130 changes: 94 additions & 36 deletions internal/spanner/services/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -2235,7 +2235,7 @@ func (s *SpannerService) DeleteSpannerTableIndex(ctx context.Context, parent str
return &emptypb.Empty{}, nil
}

func (s *SpannerService) CreateSpannerTableForeignKeysConstraint(ctx context.Context, parent string, constraint *SpannerTableForeignKeysConstraint) (*SpannerTableForeignKeysConstraint, error) {
func (s *SpannerService) CreateSpannerTableForeignKeyConstraint(ctx context.Context, parent string, constraint *SpannerTableForeignKeyConstraint) (*SpannerTableForeignKeyConstraint, error) {
// Validate parent
googleSqlParentValid := utils.ValidateArgument(parent, utils.SpannerGoogleSqlTableNameRegex)
postgresSqlParentValid := utils.ValidateArgument(parent, utils.SpannerPostgresSqlTableNameRegex)
Expand All @@ -2251,40 +2251,33 @@ func (s *SpannerService) CreateSpannerTableForeignKeysConstraint(ctx context.Con
if !googleSqlConstraintIdValid && !postgresSqlConstraintIdValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.name (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", constraint.Name, utils.SpannerGoogleSqlConstraintIdRegex, utils.SpannerPostgresSqlConstraintIdRegex)
}
if constraint.ForeignKeys == nil || len(constraint.ForeignKeys) == 0 {
return nil, status.Error(codes.InvalidArgument, "Invalid argument constraint.foreign_keys, field is required but not provided")
}
// Validate foreign key fields
for i, foreignKey := range constraint.ForeignKeys {
if foreignKey == nil {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d], field is required but not provided", i)
}
if foreignKey.ReferencedTable == "" {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].referenced_table, field is required but not provided", i)
}
googleSqlForeignKeyTableValid := utils.ValidateArgument(foreignKey.ReferencedTable, utils.SpannerGoogleSqlTableNameRegex)
postgresSqlForeignKeyTableValid := utils.ValidateArgument(foreignKey.ReferencedTable, utils.SpannerPostgresSqlTableNameRegex)
if !googleSqlForeignKeyTableValid && !postgresSqlForeignKeyTableValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].referenced_table (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", i, foreignKey.ReferencedTable, utils.SpannerGoogleSqlTableNameRegex, utils.SpannerPostgresSqlTableNameRegex)
}

if foreignKey.ReferencedColumn == "" {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].referenced_columns, field is required but not provided", i)
}
googleSqlForeignKeyColumnValid := utils.ValidateArgument(foreignKey.ReferencedColumn, utils.SpannerGoogleSqlColumnIdRegex)
postgresSqlForeignKeyColumnValid := utils.ValidateArgument(foreignKey.ReferencedColumn, utils.SpannerPostgresSqlColumnIdRegex)
if !googleSqlForeignKeyColumnValid && !postgresSqlForeignKeyColumnValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].referenced_columns (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", i, foreignKey.ReferencedColumn, utils.SpannerGoogleSqlColumnIdRegex, utils.SpannerPostgresSqlColumnIdRegex)
}
if constraint.ReferencedTable == "" {
return nil, status.Error(codes.InvalidArgument, "Invalid argument constraint.referenced_table, field is required but not provided")
}
googleSqlForeignKeyTableValid := utils.ValidateArgument(constraint.ReferencedTable, utils.SpannerGoogleSqlTableIdRegex)
postgresSqlForeignKeyTableValid := utils.ValidateArgument(constraint.ReferencedTable, utils.SpannerPostgresSqlTableIdRegex)
if !googleSqlForeignKeyTableValid && !postgresSqlForeignKeyTableValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.referenced_table (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", constraint.ReferencedTable, utils.SpannerGoogleSqlTableNameRegex, utils.SpannerPostgresSqlTableNameRegex)
}

if foreignKey.Column == "" {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].column, field is required but not provided", i)
}
googleSqlColumnValid := utils.ValidateArgument(foreignKey.Column, utils.SpannerGoogleSqlColumnIdRegex)
postgresSqlColumnValid := utils.ValidateArgument(foreignKey.Column, utils.SpannerPostgresSqlColumnIdRegex)
if !googleSqlColumnValid && !postgresSqlColumnValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.foreign_keys[%d].column (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", i, foreignKey.Column, utils.SpannerGoogleSqlColumnIdRegex, utils.SpannerPostgresSqlColumnIdRegex)
}
if constraint.ReferencedColumn == "" {
return nil, status.Error(codes.InvalidArgument, "Invalid argument constraint.referenced_column, field is required but not provided")
}
googleSqlForeignKeyColumnValid := utils.ValidateArgument(constraint.ReferencedColumn, utils.SpannerGoogleSqlColumnIdRegex)
postgresSqlForeignKeyColumnValid := utils.ValidateArgument(constraint.ReferencedColumn, utils.SpannerPostgresSqlColumnIdRegex)
if !googleSqlForeignKeyColumnValid && !postgresSqlForeignKeyColumnValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.referenced_column (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", constraint.ReferencedColumn, utils.SpannerGoogleSqlColumnIdRegex, utils.SpannerPostgresSqlColumnIdRegex)
}

if constraint.Column == "" {
return nil, status.Error(codes.InvalidArgument, "Invalid argument constraint.column, field is required but not provided")
}
googleSqlColumnValid := utils.ValidateArgument(constraint.Column, utils.SpannerGoogleSqlColumnIdRegex)
postgresSqlColumnValid := utils.ValidateArgument(constraint.Column, utils.SpannerPostgresSqlColumnIdRegex)
if !googleSqlColumnValid && !postgresSqlColumnValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument constraint.column (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", constraint.Column, utils.SpannerGoogleSqlColumnIdRegex, utils.SpannerPostgresSqlColumnIdRegex)
}

// Deconstruct parent name to get project, instance, database and table
Expand All @@ -2310,14 +2303,79 @@ func (s *SpannerService) CreateSpannerTableForeignKeysConstraint(ctx context.Con
return nil, status.Errorf(codes.Internal, "Error connecting to database: %v", err)
}

sqlStatement := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s", tableId, constraint.Name)
for _, foreignKey := range constraint.ForeignKeys {
sqlStatement += fmt.Sprintf(" FOREIGN KEY (%s) REFERENCES %s(%s)", foreignKey.Column, foreignKey.ReferencedTable, foreignKey.ReferencedColumn)
sqlStatement := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)", tableId, constraint.Name, constraint.Column, constraint.ReferencedTable, constraint.ReferencedColumn)
if constraint.OnDelete != SpannerTableForeignKeyConstraintActionUnspecified {
sqlStatement += fmt.Sprintf(" ON DELETE %s", constraint.OnDelete.String())
}

if err := db.Exec(sqlStatement).Error; err != nil {
return nil, status.Errorf(codes.Internal, "Error creating foreign key constraint: %v", err)
}

return constraint, nil
}

func (s *SpannerService) GetSpannerTableForeignKeyConstraint(ctx context.Context, parent string, name string) (*SpannerTableForeignKeyConstraint, error) {
// Validate parent
googleSqlParentValid := utils.ValidateArgument(parent, utils.SpannerGoogleSqlTableNameRegex)
postgresSqlParentValid := utils.ValidateArgument(parent, utils.SpannerPostgresSqlTableNameRegex)
if !googleSqlParentValid && !postgresSqlParentValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument parent (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", parent, utils.SpannerGoogleSqlTableNameRegex, utils.SpannerPostgresSqlTableNameRegex)
}

// Validate name
googleSqlConstraintIdValid := utils.ValidateArgument(name, utils.SpannerGoogleSqlConstraintIdRegex)
postgresSqlConstraintIdValid := utils.ValidateArgument(name, utils.SpannerPostgresSqlConstraintIdRegex)
if !googleSqlConstraintIdValid && !postgresSqlConstraintIdValid {
return nil, status.Errorf(codes.InvalidArgument, "Invalid argument name (%s), must match `%s` for GoogleSql dialect or `%s` for PostgreSQL dialect", name, utils.SpannerGoogleSqlConstraintIdRegex, utils.SpannerPostgresSqlConstraintIdRegex)
}

// Deconstruct parent name to get project, instance, database and table
parentNameParts := strings.Split(parent, "/")
project := parentNameParts[1]
instance := parentNameParts[3]
databaseId := parentNameParts[5]
tableId := parentNameParts[7]

db, err := gorm.Open(
spannergorm.New(
spannergorm.Config{
DriverName: "spanner",
DSN: fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, databaseId),
},
),
&gorm.Config{
PrepareStmt: true,
Logger: tfLogger,
},
)
if err != nil {
return nil, status.Errorf(codes.Internal, "Error connecting to database: %v", err)
}

sqlStatement := `
SELECT
TABLE_CONSTRAINTS.CONSTRAINT_NAME,
TABLE_CONSTRAINTS.TABLE_NAME,
TABLE_CONSTRAINTS.CONSTRAINT_TYPE,
REFERENTIAL_CONSTRAINTS.UPDATE_RULE,
REFERENTIAL_CONSTRAINTS.DELETE_RULE
FROM
INFORMATION_SCHEMA.TABLE_CONSTRAINTS
INNER JOIN
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS
ON
TABLE_CONSTRAINTS.CONSTRAINT_NAME = REFERENTIAL_CONSTRAINTS.CONSTRAINT_NAME
WHERE TABLE_CONSTRAINTS.TABLE_NAME = ? and TABLE_CONSTRAINTS.CONSTRAINT_NAME = ? AND TABLE_CONSTRAINTS.CONSTRAINT_TYPE = "FOREIGN KEY"
`

var result *Constraint
db = db.Raw(sqlStatement, tableId, name).Scan(&result)
if db.Error != nil {
return nil, status.Errorf(codes.Internal, "Error getting foreign key constraint: %v", db.Error)
}
if result == nil {
return nil, status.Errorf(codes.NotFound, "Foreign key constraint %s not found", name)
}

return nil, nil
}
Loading

0 comments on commit 90a23ad

Please sign in to comment.