diff --git a/ocppj/central_system_test.go b/ocppj/central_system_test.go index 5d6ec7ed..affcc66d 100644 --- a/ocppj/central_system_test.go +++ b/ocppj/central_system_test.go @@ -5,6 +5,7 @@ import ( "fmt" "strconv" "sync" + "testing" "time" "github.com/stretchr/testify/assert" @@ -54,6 +55,20 @@ func (suite *OcppJTestSuite) TestServerStoppedError() { assert.Error(t, err, "ocppj server is not started, couldn't send request") } +func (suite *OcppJTestSuite) TestServerStopBeforeStart() { + t := suite.T() + + // Stop server + suite.mockServer.On("Stop").Return(nil) + suite.centralSystem.Stop() + + // Start server (should return) + suite.mockServer.On("Start", mock.AnythingOfType("int"), mock.AnythingOfType("string")).Return(nil) + suite.centralSystem.Start(8887, "/{ws}") + + assert.True(t, suite.serverDispatcher.IsRunning()) +} + // ----------------- SendRequest tests ----------------- func (suite *OcppJTestSuite) TestCentralSystemSendRequest() { @@ -749,3 +764,9 @@ func (suite *OcppJTestSuite) TestServerRequestFlow() { q, _ = suite.serverRequestMap.Get(mockChargePoint2) assert.True(t, q.IsEmpty()) } + +func TestStopBeforeStart(t *testing.T) { + + s := ocppj.NewServer(nil, nil, nil) + s.Stop() +} diff --git a/ocppj/dispatcher.go b/ocppj/dispatcher.go index f7ef9ab4..0ae07ae5 100644 --- a/ocppj/dispatcher.go +++ b/ocppj/dispatcher.go @@ -381,6 +381,7 @@ func NewDefaultServerDispatcher(queueMap ServerQueueMap) *DefaultServerDispatche requestChannel: nil, readyForDispatch: make(chan string, 1), timeout: defaultMessageTimeout, + stoppedC: make(chan struct{}, 1), } d.pendingRequestState = NewServerState(&d.mutex) return d @@ -389,7 +390,6 @@ func NewDefaultServerDispatcher(queueMap ServerQueueMap) *DefaultServerDispatche func (d *DefaultServerDispatcher) Start() { d.requestChannel = make(chan string, 20) d.timerC = make(chan string, 10) - d.stoppedC = make(chan struct{}, 1) d.running = true go d.messagePump() } @@ -404,7 +404,12 @@ func (d *DefaultServerDispatcher) Stop() { d.mutex.Lock() defer d.mutex.Unlock() d.running = false - close(d.stoppedC) + + select { + case <-d.stoppedC: + default: + close(d.stoppedC) + } } func (d *DefaultServerDispatcher) SetTimeout(timeout time.Duration) { diff --git a/ocppj/server.go b/ocppj/server.go index 6aa0f91e..079d3d90 100644 --- a/ocppj/server.go +++ b/ocppj/server.go @@ -67,7 +67,8 @@ func NewServer(wsServer ws.WsServer, dispatcher ServerDispatcher, stateHandler S server: wsServer, RequestState: stateHandler, dispatcher: dispatcher, - stopped: make(chan struct{})} + stopped: make(chan struct{}), + } for _, profile := range profiles { s.AddProfile(profile) } @@ -148,7 +149,11 @@ func (s *Server) Start(listenPort int, listenPath string) { // Stops the server. // This clears all pending requests and causes the Start function to return. func (s *Server) Stop() { - close(s.stopped) + select { + case <-s.stopped: + default: + close(s.stopped) + } s.waitGroup.Wait() s.server.Stop() s.dispatcher.Stop()