Skip to content

Commit

Permalink
Merge pull request #670 from ackleymi/prevent-store-reset-unauth
Browse files Browse the repository at this point in the history
Check logon auth before resetting store
  • Loading branch information
ackleymi authored Sep 10, 2024
2 parents fa2e438 + bb3e854 commit 2903198
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
4 changes: 2 additions & 2 deletions in_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (state inSession) Timeout(session *session, event internal.Event) (nextStat
}

func (state inSession) handleLogout(session *session, msg *Message) (nextState sessionState) {
if err := session.verifySelect(msg, false, false); err != nil {
if err := session.verifySelect(msg, false, false, true); err != nil {
return state.processReject(session, msg, err)
}

Expand Down Expand Up @@ -154,7 +154,7 @@ func (state inSession) handleSequenceReset(session *session, msg *Message) (next
}
}

if err := session.verifySelect(msg, bool(gapFillFlag), bool(gapFillFlag)); err != nil {
if err := session.verifySelect(msg, bool(gapFillFlag), bool(gapFillFlag), true); err != nil {
return state.processReject(session, msg, err)
}

Expand Down
22 changes: 22 additions & 0 deletions logon_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,27 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonExpectResetSeqNum()
s.NextSenderMsgSeqNum(2)
}

func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonRejectedSeqNumNotReset() {
s.session.InitiateLogon = true
s.session.sentReset = true
s.Require().Nil(s.store.IncrNextSenderMsgSeqNum())

logon := s.Logon()
logon.Body.SetField(tagHeartBtInt, FIXInt(32))
logon.Body.SetField(tagResetSeqNumFlag, FIXBoolean(true))

s.MockApp.On("FromAdmin").Return(RejectLogon{"reject message"})
s.MockApp.On("OnLogout")
s.MockApp.On("ToAdmin")
s.fixMsgIn(s.session, logon)

s.MockApp.AssertExpectations(s.T())
s.State(latentState{})

s.NextTargetMsgSeqNum(2)
s.NextSenderMsgSeqNum(3)
}

func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonUnExpectedResetSeqNum() {
s.session.InitiateLogon = true
s.session.sentReset = false
Expand Down Expand Up @@ -358,6 +379,7 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooLow() {
logon.Body.SetField(tagHeartBtInt, FIXInt(32))
logon.Header.SetInt(tagMsgSeqNum, 1)

s.MockApp.On("FromAdmin").Return(nil)
s.MockApp.On("ToAdmin")
s.NextTargetMsgSeqNum(2)
s.fixMsgIn(s.session, logon)
Expand Down
27 changes: 21 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ func (s *session) handleLogon(msg *Message) error {
}
}

nextSenderMsgNumAtLogonReceived := s.store.NextSenderMsgSeqNum()

// Make sure this is a valid session before resetting the store.
if err := s.verifyMsgAgainstAppImpl(msg); err != nil {
return err
}

var resetSeqNumFlag FIXBoolean
if err := msg.Body.GetField(tagResetSeqNumFlag, &resetSeqNumFlag); err == nil {
if resetSeqNumFlag {
Expand All @@ -517,14 +524,14 @@ func (s *session) handleLogon(msg *Message) error {
}
}

nextSenderMsgNumAtLogonReceived := s.store.NextSenderMsgSeqNum()

if resetStore {
if err := s.store.Reset(); err != nil {
return err
}
}

// Verify seq num too high but dont check against app implementation since we just did that.
// Don't need to double check.
if err := s.verifyIgnoreSeqNumTooHigh(msg); err != nil {
return err
}
Expand Down Expand Up @@ -586,18 +593,18 @@ func (s *session) initiateLogoutInReplyTo(reason string, inReplyTo *Message) (er
}

func (s *session) verify(msg *Message) MessageRejectError {
return s.verifySelect(msg, true, true)
return s.verifySelect(msg, true, true, true)
}

func (s *session) verifyIgnoreSeqNumTooHigh(msg *Message) MessageRejectError {
return s.verifySelect(msg, false, true)
return s.verifySelect(msg, false, true, false)
}

func (s *session) verifyIgnoreSeqNumTooHighOrLow(msg *Message) MessageRejectError {
return s.verifySelect(msg, false, false)
return s.verifySelect(msg, false, false, true)
}

func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool) MessageRejectError {
func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool, checkAppImpl bool) MessageRejectError {
if reject := s.checkBeginString(msg); reject != nil {
return reject
}
Expand Down Expand Up @@ -626,6 +633,14 @@ func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool
}
}

if checkAppImpl {
return s.verifyMsgAgainstAppImpl(msg)
}

return nil
}

func (s *session) verifyMsgAgainstAppImpl(msg *Message) MessageRejectError {
if s.Validator != nil {
if reject := s.Validator.Validate(msg); reject != nil {
return reject
Expand Down

0 comments on commit 2903198

Please sign in to comment.