From 3bdbc4eef9a6fc310d82ffa5df7703ff934c9b1e Mon Sep 17 00:00:00 2001 From: Onsi Fakhouri Date: Tue, 29 Oct 2024 13:42:10 -0600 Subject: [PATCH] stop memoizing result of HaveField fixes #787 --- matchers/have_field.go | 36 +++++++++++++++++++++++------------- matchers/have_field_test.go | 6 ++++-- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/matchers/have_field.go b/matchers/have_field.go index 6989f78c4..8dd3f871a 100644 --- a/matchers/have_field.go +++ b/matchers/have_field.go @@ -17,7 +17,7 @@ func (e missingFieldError) Error() string { return string(e) } -func extractField(actual interface{}, field string, matchername string) (interface{}, error) { +func extractField(actual interface{}, field string, matchername string) (any, error) { fields := strings.SplitN(field, ".", 2) actualValue := reflect.ValueOf(actual) @@ -64,36 +64,46 @@ func extractField(actual interface{}, field string, matchername string) (interfa type HaveFieldMatcher struct { Field string Expected interface{} +} - extractedField interface{} - expectedMatcher omegaMatcher +func (matcher *HaveFieldMatcher) expectedMatcher() omegaMatcher { + var isMatcher bool + expectedMatcher, isMatcher := matcher.Expected.(omegaMatcher) + if !isMatcher { + expectedMatcher = &EqualMatcher{Expected: matcher.Expected} + } + return expectedMatcher } func (matcher *HaveFieldMatcher) Match(actual interface{}) (success bool, err error) { - matcher.extractedField, err = extractField(actual, matcher.Field, "HaveField") + extractedField, err := extractField(actual, matcher.Field, "HaveField") if err != nil { return false, err } - var isMatcher bool - matcher.expectedMatcher, isMatcher = matcher.Expected.(omegaMatcher) - if !isMatcher { - matcher.expectedMatcher = &EqualMatcher{Expected: matcher.Expected} - } - - return matcher.expectedMatcher.Match(matcher.extractedField) + return matcher.expectedMatcher().Match(extractedField) } func (matcher *HaveFieldMatcher) FailureMessage(actual interface{}) (message string) { + extractedField, err := extractField(actual, matcher.Field, "HaveField") + if err != nil { + // this really shouldn't happen + return fmt.Sprintf("Failed to extract field '%s': %s", matcher.Field, err) + } message = fmt.Sprintf("Value for field '%s' failed to satisfy matcher.\n", matcher.Field) - message += matcher.expectedMatcher.FailureMessage(matcher.extractedField) + message += matcher.expectedMatcher().FailureMessage(extractedField) return message } func (matcher *HaveFieldMatcher) NegatedFailureMessage(actual interface{}) (message string) { + extractedField, err := extractField(actual, matcher.Field, "HaveField") + if err != nil { + // this really shouldn't happen + return fmt.Sprintf("Failed to extract field '%s': %s", matcher.Field, err) + } message = fmt.Sprintf("Value for field '%s' satisfied matcher, but should not have.\n", matcher.Field) - message += matcher.expectedMatcher.NegatedFailureMessage(matcher.extractedField) + message += matcher.expectedMatcher().NegatedFailureMessage(extractedField) return message } diff --git a/matchers/have_field_test.go b/matchers/have_field_test.go index c49bc8041..c979944ad 100644 --- a/matchers/have_field_test.go +++ b/matchers/have_field_test.go @@ -140,15 +140,17 @@ var _ = Describe("HaveField", func() { }) Describe("Failure Messages", func() { - It("renders the underlying matcher failure", func() { + It("renders the underlying matcher failure without caching the object", func() { matcher := HaveField("Title", "Les Mis") success, err := matcher.Match(book) Ω(success).Should(BeFalse()) Ω(err).ShouldNot(HaveOccurred()) + book.Title = "Les Miser" msg := matcher.FailureMessage(book) - Ω(msg).Should(Equal("Value for field 'Title' failed to satisfy matcher.\nExpected\n : Les Miserables\nto equal\n : Les Mis")) + Ω(msg).Should(Equal("Value for field 'Title' failed to satisfy matcher.\nExpected\n : Les Miser\nto equal\n : Les Mis")) + book.Title = "Les Miserables" matcher = HaveField("Title", "Les Miserables") success, err = matcher.Match(book) Ω(success).Should(BeTrue())