Skip to content

Commit

Permalink
fix: add special case for MySQL
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Aug 10, 2023
1 parent bf7b3ef commit 802e3a3
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 7 deletions.
6 changes: 5 additions & 1 deletion consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,11 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque
} else if err != nil {
return err
} else {
_ = s.r.Kratos().DisableSession(ctx, session.KratosSessionID.String())
innerErr := s.r.Kratos().DisableSession(ctx, session.KratosSessionID.String())
if innerErr != nil {
s.r.Logger().WithError(innerErr).WithField("sid", sid).Error("Unable to revoke session in ORY Kratos.")
}
// We don't return the error here because we don't want to break the logout flow if Kratos is down.
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion flow/consent_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ type HandledLoginRequest struct {
// If specified, we will use this value to propagate the logout.
//
// required: false
KratosSessionID string `json:"session_id"`
KratosSessionID string `json:"kratos_session_id,omitempty"`

// ForceSubjectIdentifier forces the "pairwise" user ID of the end-user that authenticated. The "pairwise" user ID refers to the
// (Pairwise Identifier Algorithm)[http://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg] of the OpenID
Expand Down
4 changes: 2 additions & 2 deletions oauth2/fosite_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func setupRegistries(t *testing.T) {
}

func TestManagers(t *testing.T) {
ctx := context.TODO()
ctx := context.Background()
tests := []struct {
name string
enableSessionEncrypted bool
Expand All @@ -67,7 +67,7 @@ func TestManagers(t *testing.T) {
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Run("suite="+tc.name, func(t *testing.T) {
setupRegistries(t)

require.NoError(t, registries["memory"].ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foobar"})) // this is a workaround because the client is not being created for memory store by test helpers.
Expand Down
1 change: 0 additions & 1 deletion persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func CompareWithFixture(t *testing.T, actual interface{}, prefix string, id stri
}

func TestMigrations(t *testing.T) {
//pop.Debug = true
connections := make(map[string]*pop.Connection, 1)

if testing.Short() {
Expand Down
35 changes: 35 additions & 0 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,11 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession")
defer otelx.End(span, &err)

if p.Connection(ctx).Dialect.Name() == "mysql" {
// MySQL does not support RETURNING.
return p.mySQLDeleteLoginSession(ctx, id)
}

var session flow.LoginSession

err = p.Connection(ctx).RawQuery(
Expand All @@ -477,6 +482,36 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedS
return &session, nil
}

func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (*flow.LoginSession, error) {
var session flow.LoginSession

err := p.Connection(ctx).Transaction(func(tx *pop.Connection) error {
err := tx.RawQuery(`
SELECT * FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
id,
p.NetworkID(ctx),
).First(&session)
if err != nil {
return err
}

return p.Connection(ctx).RawQuery(`
DELETE FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
id,
p.NetworkID(ctx),
).Exec()
})

if err != nil {
return nil, sqlcon.HandleError(err)
}

return &session, nil

}

func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests")
defer span.End()
Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() {
require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls))

// Expects the login session to be confirmed in the correct context.
require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember))
require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember))
actual := &flow.LoginSession{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID))
exp, _ := json.Marshal(ls)
Expand All @@ -199,7 +199,7 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() {

// Can't find the login session in the wrong context.
require.ErrorIs(t,
r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember),
r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now().UTC(), ls.Subject, !ls.Remember),
x.ErrNotFound,
)
})
Expand Down

0 comments on commit 802e3a3

Please sign in to comment.