diff --git a/client.go b/client.go index 2a9051b6..a15b51a9 100644 --- a/client.go +++ b/client.go @@ -77,9 +77,9 @@ type client struct { stop chan struct{} persist Store options ClientOptions - lastSent time.Time - lastReceived time.Time - pingOutstanding bool + lastSent int64 + lastReceived int64 + pingOutstanding int32 status uint32 workers sync.WaitGroup } @@ -239,8 +239,8 @@ func (c *client) Connect() Token { c.stop = make(chan struct{}) if c.options.KeepAlive != 0 { - c.lastReceived = time.Now() - c.lastSent = time.Now() + atomic.StoreInt64(&c.lastReceived, time.Now().Unix()) + atomic.StoreInt64(&c.lastSent, time.Now().Unix()) c.workers.Add(1) go keepalive(c) } @@ -342,9 +342,9 @@ func (c *client) reconnect() { } if c.options.KeepAlive != 0 { - c.pingOutstanding = false - c.lastReceived = time.Now() - c.lastSent = time.Now() + atomic.StoreInt32(&c.pingOutstanding, 0) + atomic.StoreInt64(&c.lastReceived, time.Now().Unix()) + atomic.StoreInt64(&c.lastSent, time.Now().Unix()) c.workers.Add(1) go keepalive(c) } diff --git a/message.go b/message.go index b1b71648..720df5be 100644 --- a/message.go +++ b/message.go @@ -98,7 +98,7 @@ func newConnectMsgFromOptions(options *ClientOptions) *packets.ConnectPacket { } } - m.Keepalive = uint16(options.KeepAlive.Seconds()) + m.Keepalive = uint16(options.KeepAlive) return m } diff --git a/net.go b/net.go index 56c7b180..9fd33ec6 100644 --- a/net.go +++ b/net.go @@ -22,6 +22,7 @@ import ( "net/url" "os" "reflect" + "sync/atomic" "time" "github.com/eclipse/paho.mqtt.golang/packets" @@ -125,7 +126,7 @@ func incoming(c *client) { case c.ibound <- cp: // Notify keepalive logic that we recently received a packet if c.options.KeepAlive != 0 { - c.lastReceived = time.Now() + atomic.StoreInt64(&c.lastReceived, time.Now().Unix()) } case <-c.stop: // This avoids a deadlock should a message arrive while shutting down. @@ -205,7 +206,7 @@ func outgoing(c *client) { } // Reset ping timer after sending control packet. if c.options.KeepAlive != 0 { - c.lastSent = time.Now() + atomic.StoreInt64(&c.lastSent, time.Now().Unix()) } } } @@ -228,7 +229,7 @@ func alllogic(c *client) { switch m := msg.(type) { case *packets.PingrespPacket: DEBUG.Println(NET, "received pingresp") - c.pingOutstanding = false + atomic.StoreInt32(&c.pingOutstanding, 0) case *packets.SubackPacket: DEBUG.Println(NET, "received suback, id:", m.MessageID) token := c.getToken(m.MessageID) diff --git a/options.go b/options.go index 956772c4..6acb5559 100644 --- a/options.go +++ b/options.go @@ -52,7 +52,7 @@ type ClientOptions struct { ProtocolVersion uint protocolVersionExplicit bool TLSConfig tls.Config - KeepAlive time.Duration + KeepAlive int64 PingTimeout time.Duration ConnectTimeout time.Duration MaxReconnectInterval time.Duration @@ -90,7 +90,7 @@ func NewClientOptions() *ClientOptions { ProtocolVersion: 0, protocolVersionExplicit: false, TLSConfig: tls.Config{}, - KeepAlive: 30 * time.Second, + KeepAlive: 30, PingTimeout: 10 * time.Second, ConnectTimeout: 30 * time.Second, MaxReconnectInterval: 10 * time.Minute, @@ -182,7 +182,7 @@ func (o *ClientOptions) SetStore(s Store) *ClientOptions { // allow the client to know that a connection has not been lost with the // server. func (o *ClientOptions) SetKeepAlive(k time.Duration) *ClientOptions { - o.KeepAlive = k + o.KeepAlive = int64(k / time.Second) return o } diff --git a/options_reader.go b/options_reader.go index aab6e453..81674fba 100644 --- a/options_reader.go +++ b/options_reader.go @@ -102,7 +102,7 @@ func (r *ClientOptionsReader) TLSConfig() tls.Config { } func (r *ClientOptionsReader) KeepAlive() time.Duration { - s := r.options.KeepAlive + s := time.Duration(r.options.KeepAlive * int64(time.Second)) return s } diff --git a/ping.go b/ping.go index f9f05fdf..6df939ef 100644 --- a/ping.go +++ b/ping.go @@ -16,6 +16,7 @@ package mqtt import ( "errors" + "sync/atomic" "time" "github.com/eclipse/paho.mqtt.golang/packets" @@ -24,16 +25,16 @@ import ( func keepalive(c *client) { defer c.workers.Done() DEBUG.Println(PNG, "keepalive starting") - var checkInterval time.Duration + var checkInterval int64 var pingSent time.Time - if c.options.KeepAlive > 10*time.Second { - checkInterval = 5 * time.Second + if c.options.KeepAlive > 10 { + checkInterval = 5 } else { checkInterval = c.options.KeepAlive / 2 } - intervalTicker := time.NewTicker(checkInterval) + intervalTicker := time.NewTicker(time.Duration(checkInterval * int64(time.Second))) defer intervalTicker.Stop() for { @@ -42,19 +43,20 @@ func keepalive(c *client) { DEBUG.Println(PNG, "keepalive stopped") return case <-intervalTicker.C: - if time.Now().Sub(c.lastSent) >= c.options.KeepAlive || time.Now().Sub(c.lastReceived) >= c.options.KeepAlive { - if !c.pingOutstanding { + DEBUG.Println(PNG, "ping check", time.Now().Unix()-atomic.LoadInt64(&c.lastSent)) + if time.Now().Unix()-atomic.LoadInt64(&c.lastSent) >= c.options.KeepAlive || time.Now().Unix()-atomic.LoadInt64(&c.lastReceived) >= c.options.KeepAlive { + if atomic.LoadInt32(&c.pingOutstanding) == 0 { DEBUG.Println(PNG, "keepalive sending ping") ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket) //We don't want to wait behind large messages being sent, the Write call //will block until it it able to send the packet. - c.pingOutstanding = true + atomic.StoreInt32(&c.pingOutstanding, 1) ping.Write(c.conn) - c.lastSent = time.Now() + atomic.StoreInt64(&c.lastSent, time.Now().Unix()) pingSent = time.Now() } } - if c.pingOutstanding && time.Now().Sub(pingSent) >= c.options.PingTimeout { + if atomic.LoadInt32(&c.pingOutstanding) > 0 && time.Now().Sub(pingSent) >= c.options.PingTimeout { CRITICAL.Println(PNG, "pingresp not received, disconnecting") c.errors <- errors.New("pingresp not received, disconnecting") return diff --git a/unit_options_test.go b/unit_options_test.go index be6ad8b4..e1483565 100644 --- a/unit_options_test.go +++ b/unit_options_test.go @@ -36,7 +36,7 @@ func Test_NewClientOptions_default(t *testing.T) { t.Fatalf("bad default password") } - if o.KeepAlive != 30*time.Second { + if o.KeepAlive != 30 { t.Fatalf("bad default timeout") } } @@ -69,7 +69,7 @@ func Test_NewClientOptions_mix(t *testing.T) { t.Fatalf("bad set password") } - if o.KeepAlive != 88000000000 { + if o.KeepAlive != 88 { t.Fatalf("bad set timeout: %d", o.KeepAlive) } }