Skip to content

Commit

Permalink
Merge pull request from GHSA-4hq8-gmxx-h6w9
Browse files Browse the repository at this point in the history
This change validates that the XML input we receive is safe to parse before
passing it to the standard library's XML parsing functions or the etree DOM
parsing functions.

This validation mitigates critical vulnerabilities in `encoding/xml` - CVE-2020-29509, CVE-2020-29510, and CVE-2020-29511.

TODO: is there going to be a go.mod version assigned to this on release?
  • Loading branch information
crewjam authored Dec 14, 2020
1 parent a606939 commit da4f1a0
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 16 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/jonboulle/clockwork v0.2.1 // indirect
github.com/kr/pretty v0.2.1
github.com/mattermost/xml-roundtrip-validator v0.0.0-00010101000000-000000000000
github.com/pkg/errors v0.8.1 // indirect
github.com/russellhaering/goxmldsig v1.1.0
github.com/stretchr/testify v1.6.1
Expand Down
6 changes: 6 additions & 0 deletions identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"time"

"github.com/beevik/etree"
xrv "github.com/mattermost/xml-roundtrip-validator"
dsig "github.com/russellhaering/goxmldsig"

"github.com/crewjam/saml/logger"
Expand Down Expand Up @@ -359,13 +360,18 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques
default:
return nil, fmt.Errorf("method not allowed")
}

return req, nil
}

// Validate checks that the authentication request is valid and assigns
// the AuthnRequest and Metadata properties. Returns a non-nil error if the
// request is not valid.
func (req *IdpAuthnRequest) Validate() error {
if err := xrv.Validate(bytes.NewReader(req.RequestBuffer)); err != nil {
return err
}

if err := xml.Unmarshal(req.RequestBuffer, &req.Request); err != nil {
return err
}
Expand Down
33 changes: 31 additions & 2 deletions identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package saml

import (
"bytes"
"compress/flate"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -232,15 +235,41 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) {
func TestIDPHTTPCanHandleSSORequest(t *testing.T) {
test := NewIdentifyProviderTest()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D", nil)

const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D`

r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
"SAMLRequest="+validRequest, nil)
test.IDP.Handler().ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)

// rejects requests that are invalid
w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil)
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
"SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil)
test.IDP.Handler().ServeHTTP(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)

// rejects requests that contain malformed XML
{
a, _ := url.QueryUnescape(validRequest)
b, _ := base64.StdEncoding.DecodeString(a)
c, _ := ioutil.ReadAll(flate.NewReader(bytes.NewReader(b)))
d := bytes.Replace(c, []byte("<AuthnRequest"), []byte("<AuthnRequest ::foo=\"bar\""), 1)
f := bytes.Buffer{}
e, _ := flate.NewWriter(&f, flate.DefaultCompression)
e.Write(d)
e.Close()
g := base64.StdEncoding.EncodeToString(f.Bytes())
invalidRequest := url.QueryEscape(g)

w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
"SAMLRequest="+invalidRequest, nil)
test.IDP.Handler().ServeHTTP(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
}

}

func TestIDPCanHandleRequestWithNewSession(t *testing.T) {
Expand Down
20 changes: 11 additions & 9 deletions samlidp/util.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package samlidp

import (
"bytes"
"encoding/xml"
"errors"
"io"
"io/ioutil"

"encoding/xml"

"io"
xrv "github.com/mattermost/xml-roundtrip-validator"

"github.com/crewjam/saml"
)
Expand All @@ -20,19 +21,20 @@ func randomBytes(n int) []byte {
}

func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) {
var bytes []byte

if bytes, err = ioutil.ReadAll(r); err != nil {
var data []byte
if data, err = ioutil.ReadAll(r); err != nil {
return nil, err
}

spMetadata = &saml.EntityDescriptor{}
if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
return nil, err
}

if err := xml.Unmarshal(bytes, &spMetadata); err != nil {
if err := xml.Unmarshal(data, &spMetadata); err != nil {
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
entities := &saml.EntitiesDescriptor{}

if err := xml.Unmarshal(bytes, &entities); err != nil {
if err := xml.Unmarshal(data, &entities); err != nil {
return nil, err
}

Expand Down
22 changes: 22 additions & 0 deletions samlidp/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package samlidp

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetSPMetadata(t *testing.T) {
good := "" +
"<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2013-03-10T00:32:19.104Z\" cacheDuration=\"PT1H\" entityID=\"http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/\">\n" +
"</EntityDescriptor>"
_, err := getSPMetadata(strings.NewReader(good))
assert.NoError(t, err)

bad := "" +
"<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" ::attr=\"foo\" validUntil=\"2013-03-10T00:32:19.104Z\" cacheDuration=\"PT1H\" entityID=\"http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/\">\n" +
"</EntityDescriptor>"
_, err = getSPMetadata(strings.NewReader(bad))
assert.EqualError(t, err, "validator: in token starting at 1:1: roundtrip error: expected {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ :attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}, observed {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}")
}
7 changes: 7 additions & 0 deletions samlsp/fetch_metadata.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package samlsp

import (
"bytes"
"context"
"encoding/xml"
"errors"
Expand All @@ -9,6 +10,7 @@ import (
"net/url"

"github.com/crewjam/httperr"
xrv "github.com/mattermost/xml-roundtrip-validator"

"github.com/crewjam/saml"
)
Expand All @@ -20,6 +22,11 @@ import (
// <EntityDescriptor>.
func ParseMetadata(data []byte) (*saml.EntityDescriptor, error) {
entity := &saml.EntityDescriptor{}

if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
return nil, err
}

err := xml.Unmarshal(data, entity)

// this comparison is ugly, but it is how the error is generated in encoding/xml
Expand Down
17 changes: 17 additions & 0 deletions samlsp/fetch_metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -25,3 +26,19 @@ func TestFetchMetadata(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "https://idp.testshib.org/idp/shibboleth", md.EntityID)
}

func TestFetchMetadataRejectsInvalid(t *testing.T) {
test := NewMiddlewareTest()
test.IDPMetadata = strings.Replace(test.IDPMetadata, "<EntityDescriptor ", "<EntityDescriptor ::foo=\"bar\"", -1)

testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/metadata", r.URL.String())
fmt.Fprint(w, test.IDPMetadata)
}))

fmt.Println(testServer.URL + "/metadata")
u, _ := url.Parse(testServer.URL + "/metadata")
md, err := FetchMetadata(context.Background(), testServer.Client(), *u)
assert.EqualError(t, err, "validator: in token starting at 2:1: roundtrip error: expected {{ EntityDescriptor} [{{ :foo} bar} {{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{xmlns ds} http://www.w3.org/2000/09/xmldsig#} {{xmlns mdalg} urn:oasis:names:tc:SAML:metadata:algsupport} {{xmlns mdui} urn:oasis:names:tc:SAML:metadata:ui} {{xmlns shibmd} urn:mace:shibboleth:metadata:1.0} {{xmlns xsi} http://www.w3.org/2001/XMLSchema-instance} {{ Name} urn:mace:shibboleth:testshib:two} {{ entityID} https://idp.testshib.org/idp/shibboleth}]}, observed {{ EntityDescriptor} [{{ foo} bar} {{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{xmlns ds} http://www.w3.org/2000/09/xmldsig#} {{xmlns mdalg} urn:oasis:names:tc:SAML:metadata:algsupport} {{xmlns mdui} urn:oasis:names:tc:SAML:metadata:ui} {{xmlns shibmd} urn:mace:shibboleth:metadata:1.0} {{xmlns xsi} http://www.w3.org/2001/XMLSchema-instance} {{ Name} urn:mace:shibboleth:testshib:two} {{ entityID} https://idp.testshib.org/idp/shibboleth} {{ entityID} https://idp.testshib.org/idp/shibboleth}]}")
assert.Nil(t, md)
}
38 changes: 33 additions & 5 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"errors"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"time"

xrv "github.com/mattermost/xml-roundtrip-validator"

"github.com/beevik/etree"
dsig "github.com/russellhaering/goxmldsig"
"github.com/russellhaering/goxmldsig/etreeutils"
Expand Down Expand Up @@ -553,9 +556,15 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
Response: string(decodedResponseXML),
}

// ensure that the response XML is well formed before we parse it
if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil {
retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
return nil, retErr
}

// do some validation first before we decrypt
resp := Response{}
if err := xml.Unmarshal([]byte(decodedResponseXML), &resp); err != nil {
if err := xml.Unmarshal(decodedResponseXML, &resp); err != nil {
retErr.PrivateErr = fmt.Errorf("cannot unmarshal response: %s", err)
return nil, retErr
}
Expand Down Expand Up @@ -659,6 +668,12 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
}
retErr.Response = string(plaintextAssertion)

// TODO(ross): add test case for this
if err := xrv.Validate(bytes.NewReader(plaintextAssertion)); err != nil {
retErr.PrivateErr = fmt.Errorf("plaintext response contains invalid XML: %s", err)
return nil, retErr
}

doc = etree.NewDocument()
if err := doc.ReadFromBytes(plaintextAssertion); err != nil {
retErr.PrivateErr = fmt.Errorf("cannot parse plaintext response %v", err)
Expand All @@ -673,6 +688,8 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
}

assertion = &Assertion{}
// Note: plaintextAssertion is known to be safe to parse because
// plaintextAssertion is unmodified from when xrv.Validate() was called above.
if err := xml.Unmarshal(plaintextAssertion, assertion); err != nil {
retErr.PrivateErr = err
return nil, retErr
Expand Down Expand Up @@ -1001,8 +1018,12 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
return fmt.Errorf("unable to parse base64: %s", err)
}

var resp LogoutResponse
// TODO(ross): add test case for this (SLO does not have tests right now)
if err := xrv.Validate(bytes.NewReader(rawResponseBuf)); err != nil {
return fmt.Errorf("response contains invalid XML: %s", err)
}

var resp LogoutResponse
if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil {
return fmt.Errorf("cannot unmarshal response: %s", err)
}
Expand Down Expand Up @@ -1034,9 +1055,16 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
return fmt.Errorf("unable to parse base64: %s", err)
}

gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf))
gr, err := ioutil.ReadAll(flate.NewReader(bytes.NewBuffer(rawResponseBuf)))
if err != nil {
return err
}

if err := xrv.Validate(bytes.NewReader(gr)); err != nil {
return err
}

decoder := xml.NewDecoder(gr)
decoder := xml.NewDecoder(bytes.NewReader(gr))

var resp LogoutResponse

Expand All @@ -1050,7 +1078,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
}

doc := etree.NewDocument()
if _, err := doc.ReadFrom(gr); err != nil {
if _, err := doc.ReadFrom(bytes.NewReader(gr)); err != nil {
return err
}

Expand Down
Loading

0 comments on commit da4f1a0

Please sign in to comment.