From 8a2a2fd914fe26fc1258d0cfd290ba7a8aeec2c0 Mon Sep 17 00:00:00 2001 From: Vincent Xiao Date: Fri, 9 Oct 2020 16:56:35 -0700 Subject: [PATCH] fix: bug in updating a customer's phones and addresses --- pkg/customers/customers.go | 62 ++++++++++++++++++++++++++++----- pkg/customers/customers_test.go | 21 ++++++----- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/pkg/customers/customers.go b/pkg/customers/customers.go index 69761e931..177545622 100644 --- a/pkg/customers/customers.go +++ b/pkg/customers/customers.go @@ -486,6 +486,7 @@ func (r *sqlCustomerRepository) updateCustomer(c *client.Customer, organization if err != nil { return err } + defer tx.Rollback() query := `update customers set first_name = ?, middle_name = ?, last_name = ?, nick_name = ?, suffix = ?, type = ?, birth_date = ?, status = ?, email =?, last_modified = ?, @@ -528,17 +529,39 @@ func (r *sqlCustomerRepository) updateCustomer(c *client.Customer, organization } func (r *sqlCustomerRepository) updatePhonesByCustomerID(tx *sql.Tx, customerID string, phones []client.Phone) error { - query := `replace into customers_phones (customer_id, number, valid, type) values (?, ?, ?, ?);` - stmt, err := tx.Prepare(query) + deleteQuery := `delete from customers_phones where customer_id = ?` + var args []interface{} + args = append(args, customerID) + if len(phones) > 0 { + deleteQuery = fmt.Sprintf("%s and number not in (?%s)", deleteQuery, strings.Repeat(",?", len(phones)-1)) + for _, p := range phones { + args = append(args, p.Number) + } + } + deleteQuery = fmt.Sprintf("%s;", deleteQuery) + + stmt, err := tx.Prepare(deleteQuery) + if err != nil { + return fmt.Errorf("preparing query: %v", err) + } + defer stmt.Close() + + _, err = stmt.Exec(args...) + if err != nil { + return fmt.Errorf("executing query: %v", err) + } + + replaceQuery := `replace into customers_phones (customer_id, number, valid, type) values (?, ?, ?, ?);` + stmt, err = tx.Prepare(replaceQuery) if err != nil { - return fmt.Errorf("preparing tx update on customers_phones err=%v | rollback=%v", err, tx.Rollback()) + return fmt.Errorf("preparing query: %v", err) } defer stmt.Close() for _, phone := range phones { _, err := stmt.Exec(customerID, phone.Number, phone.Valid, phone.Type) if err != nil { - return fmt.Errorf("executing update on customers_phones err=%v | rollback=%v", err, tx.Rollback()) + return fmt.Errorf("executing update on customer's phone: %v", err) } } @@ -546,20 +569,43 @@ func (r *sqlCustomerRepository) updatePhonesByCustomerID(tx *sql.Tx, customerID } func (r *sqlCustomerRepository) updateAddressesByCustomerID(tx *sql.Tx, customerID string, addresses []client.CustomerAddress) error { - query := `replace into customers_addresses(address_id, customer_id, type, address1, address2, city, state, postal_code, country, validated) values (?, ?, ?, ?, ?, ?, ?, ?, + deleteQuery := `delete from customers_addresses where customer_id = ?` + var args []interface{} + args = append(args, customerID) + if len(addresses) > 0 { + deleteQuery = fmt.Sprintf("%s and address1 not in (?%s)", deleteQuery, strings.Repeat(",?", len(addresses)-1)) + for _, a := range addresses { + args = append(args, a.Address1) + } + } + deleteQuery = fmt.Sprintf("%s;", deleteQuery) + + stmt, err := tx.Prepare(deleteQuery) + if err != nil { + return fmt.Errorf("preparing query: %v", err) + } + defer stmt.Close() + + _, err = stmt.Exec(args...) + if err != nil { + panic(err) + } + + replaceQuery := `replace into customers_addresses(address_id, customer_id, type, address1, address2, city, state, postal_code, country, validated) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` - stmt, err := tx.Prepare(query) + stmt, err = tx.Prepare(replaceQuery) if err != nil { - return fmt.Errorf("preparing tx on customers_addresses err=%v | rollback=%v", err, tx.Rollback()) + return fmt.Errorf("preparing query: %v", err) } defer stmt.Close() for _, addr := range addresses { _, err := stmt.Exec(addr.AddressID, customerID, addr.Type, addr.Address1, addr.Address2, addr.City, addr.State, addr.PostalCode, addr.Country, addr.Validated) if err != nil { - return fmt.Errorf("executing update on customers_addresses err=%v | rollback=%v", err, tx.Rollback()) + return fmt.Errorf("executing query: %v", err) } } + return nil } diff --git a/pkg/customers/customers_test.go b/pkg/customers/customers_test.go index bfe905fe2..9ac7c9a02 100644 --- a/pkg/customers/customers_test.go +++ b/pkg/customers/customers_test.go @@ -446,6 +446,12 @@ func TestCustomers__updateCustomer(t *testing.T) { updateReq.FirstName = "Jim" updateReq.LastName = "Smith" updateReq.Email = "jim@google.com" + updateReq.Phones = []phone{ + { + Number: "555.555.5555", + Type: "cell", + }, + } updateReq.Addresses = []address{ { Address1: "555 5th st", @@ -455,7 +461,6 @@ func TestCustomers__updateCustomer(t *testing.T) { Country: "US", }, } - payload, err := json.Marshal(&updateReq) require.NoError(t, err) @@ -465,19 +470,19 @@ func TestCustomers__updateCustomer(t *testing.T) { AddCustomerRoutes(log.NewNopLogger(), router, repo, testCustomerSSNStorage(t), createTestOFACSearcher(nil, nil)) router.ServeHTTP(w, req) w.Flush() - require.Equal(t, http.StatusOK, w.Code) var got *client.Customer require.NoError(t, json.NewDecoder(w.Body).Decode(&got)) - fmt.Println(w.Body.String()) - - want, err := repo.GetCustomer(customer.CustomerID) + want, _, _ := updateReq.asCustomer(testCustomerSSNStorage(t)) require.NoError(t, err) - got.CreatedAt = want.CreatedAt - got.LastModified = want.LastModified - got.Metadata = make(map[string]string) + want.CustomerID = got.CustomerID + want.Addresses[0].AddressID = got.Addresses[0].AddressID + want.BirthDate = got.BirthDate + want.Status = got.Status + want.CreatedAt = got.CreatedAt + want.LastModified = got.LastModified require.Equal(t, want, got) }