@@ -141,6 +141,10 @@ type ServiceProvider struct {
141
141
// ValidateAudienceRestriction allows you to override the default audience validation
142
142
// for an assertion. If nil, the default audience validation is used.
143
143
ValidateAudienceRestriction func (assertion * Assertion ) error
144
+
145
+ // ValidateRequestID allows you to override the default request ID validation.
146
+ // If nil, the default request ID validation is used.
147
+ ValidateRequestID func (response Response , possibleRequestIDs []string ) error
144
148
}
145
149
146
150
// MaxIssueDelay is the longest allowed time between when a SAML assertion is
@@ -972,18 +976,8 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
972
976
}
973
977
}
974
978
975
- requestIDvalid := false
976
- if sp .AllowIDPInitiated {
977
- requestIDvalid = true
978
- } else {
979
- for _ , possibleRequestID := range possibleRequestIDs {
980
- if response .InResponseTo == possibleRequestID {
981
- requestIDvalid = true
982
- }
983
- }
984
- }
985
- if ! requestIDvalid {
986
- return nil , fmt .Errorf ("`InResponseTo` does not match any of the possible request IDs (expected %v)" , possibleRequestIDs )
979
+ if err := sp .validateRequestID (response , possibleRequestIDs ); err != nil {
980
+ return nil , err
987
981
}
988
982
989
983
if response .IssueInstant .Add (MaxIssueDelay ).Before (now ) {
@@ -1059,6 +1053,27 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ
1059
1053
return & assertions [0 ], nil
1060
1054
}
1061
1055
1056
+ func (sp * ServiceProvider ) validateRequestID (response Response , possibleRequestIDs []string ) error {
1057
+ if sp .ValidateRequestID != nil {
1058
+ return sp .ValidateRequestID (response , possibleRequestIDs )
1059
+ }
1060
+
1061
+ requestIDvalid := false
1062
+ if sp .AllowIDPInitiated {
1063
+ requestIDvalid = true
1064
+ } else {
1065
+ for _ , possibleRequestID := range possibleRequestIDs {
1066
+ if response .InResponseTo == possibleRequestID {
1067
+ requestIDvalid = true
1068
+ }
1069
+ }
1070
+ }
1071
+ if ! requestIDvalid {
1072
+ return fmt .Errorf ("`InResponseTo` does not match any of the possible request IDs (expected %v)" , possibleRequestIDs )
1073
+ }
1074
+ return nil
1075
+ }
1076
+
1062
1077
func (sp * ServiceProvider ) parseEncryptedAssertion (encryptedAssertionEl * etree.Element , possibleRequestIDs []string , now time.Time , signatureRequirement signatureRequirement ) (* Assertion , error ) {
1063
1078
assertionEl , err := sp .decryptElement (encryptedAssertionEl )
1064
1079
if err != nil {
0 commit comments