Skip to content

Commit abf97a1

Browse files
authored
add ValidateRequestID hook to allow you to override the default request ID validation. (crewjam#599)
Inspired by crewjam#581, but modified to be consistent with other hooks
1 parent 9207a5e commit abf97a1

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

service_provider.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ type ServiceProvider struct {
141141
// ValidateAudienceRestriction allows you to override the default audience validation
142142
// for an assertion. If nil, the default audience validation is used.
143143
ValidateAudienceRestriction func(assertion *Assertion) error
144+
145+
// ValidateRequestID allows you to override the default request ID validation.
146+
// If nil, the default request ID validation is used.
147+
ValidateRequestID func(response Response, possibleRequestIDs []string) error
144148
}
145149

146150
// MaxIssueDelay is the longest allowed time between when a SAML assertion is
@@ -972,18 +976,8 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
972976
}
973977
}
974978

975-
requestIDvalid := false
976-
if sp.AllowIDPInitiated {
977-
requestIDvalid = true
978-
} else {
979-
for _, possibleRequestID := range possibleRequestIDs {
980-
if response.InResponseTo == possibleRequestID {
981-
requestIDvalid = true
982-
}
983-
}
984-
}
985-
if !requestIDvalid {
986-
return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
979+
if err := sp.validateRequestID(response, possibleRequestIDs); err != nil {
980+
return nil, err
987981
}
988982

989983
if response.IssueInstant.Add(MaxIssueDelay).Before(now) {
@@ -1059,6 +1053,27 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
10591053
return &assertions[0], nil
10601054
}
10611055

1056+
func (sp *ServiceProvider) validateRequestID(response Response, possibleRequestIDs []string) error {
1057+
if sp.ValidateRequestID != nil {
1058+
return sp.ValidateRequestID(response, possibleRequestIDs)
1059+
}
1060+
1061+
requestIDvalid := false
1062+
if sp.AllowIDPInitiated {
1063+
requestIDvalid = true
1064+
} else {
1065+
for _, possibleRequestID := range possibleRequestIDs {
1066+
if response.InResponseTo == possibleRequestID {
1067+
requestIDvalid = true
1068+
}
1069+
}
1070+
}
1071+
if !requestIDvalid {
1072+
return fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
1073+
}
1074+
return nil
1075+
}
1076+
10621077
func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
10631078
assertionEl, err := sp.decryptElement(encryptedAssertionEl)
10641079
if err != nil {

0 commit comments

Comments
 (0)