Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IMPROVED] MQTT error message when client connects with websocket #2151

Merged
merged 1 commit into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions server/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ var (
mqttFlapCleanItvl = mqttSessFlappingCleanupInterval
)

var (
errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported")
errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty")
errMQTTMalformedVarInt = errors.New("malformed variable int")
errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet")
)

type srvMQTT struct {
listener net.Listener
listenerErr error
Expand Down Expand Up @@ -565,7 +572,13 @@ func (c *client) mqttParse(buf []byte) error {
// If client was not connected yet, the first packet must be
// a mqttPacketConnect otherwise we fail the connection.
if !connected && pt != mqttPacketConnect {
err = errors.New("not connected")
// Try to guess if the client is trying to connect using Websocket,
// which is currently not supported
if bytes.HasPrefix(buf, []byte("GET ")) {
err = errMQTTWebsocketNotSupported
} else {
err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt)
}
break
}

Expand Down Expand Up @@ -647,7 +660,7 @@ func (c *client) mqttParse(buf []byte) error {
case mqttPacketConnect:
// It is an error to receive a second connect packet
if connected {
err = errors.New("second connect packet")
err = errMQTTSecondConnectPacket
break
}
var rc byte
Expand Down Expand Up @@ -2913,7 +2926,7 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool)
return 0, nil, err
}
if len(topic) == 0 {
return 0, nil, errors.New("topic filter cannot be empty")
return 0, nil, errMQTTTopicFilterCannotBeEmpty
}
// Spec [MQTT-3.8.3-1], [MQTT-3.10.3-1]
if !utf8.Valid(topic) {
Expand Down Expand Up @@ -3648,7 +3661,7 @@ func (r *mqttReader) readPacketLen() (int, error) {
}
m *= 0x80
if m > 0x200000 {
return 0, errors.New("malformed variable int")
return 0, errMQTTMalformedVarInt
}
}
}
Expand Down
55 changes: 53 additions & 2 deletions server/mqtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -1381,6 +1383,9 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)

l := &captureErrorLogger{errCh: make(chan string, 10)}
s.SetLogger(l, false, false)

c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port))
if err != nil {
t.Fatalf("Error on dial: %v", err)
Expand All @@ -1393,6 +1398,15 @@ func TestMQTTConnectNotFirstPacket(t *testing.T) {
t.Fatalf("Error publishing: %v", err)
}
testMQTTExpectDisconnect(t, c)

select {
case err := <-l.errCh:
if !strings.Contains(err, "should be a CONNECT") {
t.Fatalf("Expected error about first packet being a CONNECT, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("Did not log any error")
}
}

func TestMQTTSecondConnect(t *testing.T) {
Expand Down Expand Up @@ -1691,7 +1705,7 @@ func TestMQTTParseSub(t *testing.T) {
{"error reading packet id", []byte{1}, mqttSubscribeFlags, 1, eofr, "reading packet identifier"},
{"missing filters", []byte{0, 1}, mqttSubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttSubscribeFlags, 5, eofr, "topic filter"},
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, "topic filter cannot be empty"},
{"empty topic", []byte{0, 1, 0, 0}, mqttSubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttSubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
{"missing qos", []byte{0, 1, 0, 1, 'a'}, mqttSubscribeFlags, 5, nil, "QoS"},
{"invalid qos", []byte{0, 1, 0, 1, 'a', 3}, mqttSubscribeFlags, 6, nil, "subscribe QoS value must be 0, 1 or 2"},
Expand Down Expand Up @@ -2903,7 +2917,7 @@ func TestMQTTParseUnsub(t *testing.T) {
{"error reading packet id", []byte{1}, mqttUnsubscribeFlags, 1, eofr, "reading packet identifier"},
{"missing filters", []byte{0, 1}, mqttUnsubscribeFlags, 2, nil, "subscribe protocol must contain at least 1 topic filter"},
{"error reading topic", []byte{0, 1, 0, 2, 'a'}, mqttUnsubscribeFlags, 5, eofr, "topic filter"},
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, "topic filter cannot be empty"},
{"empty topic", []byte{0, 1, 0, 0}, mqttUnsubscribeFlags, 4, nil, errMQTTTopicFilterCannotBeEmpty.Error()},
{"invalid utf8 topic", []byte{0, 1, 0, 1, 241}, mqttUnsubscribeFlags, 5, nil, "invalid utf8 for topic filter"},
} {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -4526,6 +4540,43 @@ func TestMQTTStreamInfoReturnsNonEmptySubject(t *testing.T) {
}
}

func TestMQTTWebsocketNotSupported(t *testing.T) {
o := testMQTTDefaultOptions()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)

l := &captureErrorLogger{errCh: make(chan string, 10)}
s.SetLogger(l, false, false)

addr := fmt.Sprintf("%s:%d", o.MQTT.Host, o.MQTT.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating connection: %v", err)
}
req := testWSCreateValidReq()
req.URL, _ = url.Parse("ws://" + addr)
if err := req.Write(wsc); err != nil {
t.Fatalf("Error sending request: %v", err)
}
br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if err == nil {
if resp != nil {
defer resp.Body.Close()
}
t.Fatalf("Expected error, got resp=%+v", resp)
}

select {
case err := <-l.errCh:
if !strings.Contains(err, "not supported") {
t.Fatalf("Expected error about websocket not supported, got %v", err)
}
case <-time.After(time.Second):
t.Fatal("Did not log any error")
}
}

//////////////////////////////////////////////////////////////////////////
//
// Benchmarks
Expand Down