Skip to content
This repository was archived by the owner on Jun 2, 2023. It is now read-only.

Commit 66a5ca3

Browse files
committed
handshake: rewrite sequences for clarity and correctness
1 parent 78d64d4 commit 66a5ca3

8 files changed

+274
-178
lines changed

handshake/ack_sequence.go

Lines changed: 0 additions & 79 deletions
This file was deleted.

handshake/ack_sequence_test.go

Lines changed: 0 additions & 63 deletions
This file was deleted.

handshake/client_ack_sequence.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package handshake
2+
3+
import (
4+
"crypto/rand"
5+
"io"
6+
)
7+
8+
// ClientAckSequence is a handshake sequence that verifies the challenge
9+
// sequence sent by the client during a RTMP handshake. It is responsible for
10+
// reading and responding to the C1 packet (with the C2 packet), and sending the
11+
// server challenge in S1.
12+
type ClientAckSequence struct {
13+
C1 *AckPacket
14+
S1 *AckPacket
15+
}
16+
17+
var _ Sequence = new(ClientAckSequence)
18+
19+
// NewClientAckSequence initializes and returns a new *ClientAckSequence
20+
// initialized with an empty C1 packet and a new S1 packet, initialized with the
21+
// result of a "rand.Read" into its Payload header.
22+
func NewClientAckSequence() *ClientAckSequence {
23+
c := &ClientAckSequence{
24+
C1: new(AckPacket),
25+
S1: new(AckPacket),
26+
}
27+
28+
rand.Read(c.S1.Payload[:])
29+
30+
return c
31+
}
32+
33+
// Read implements the Sequence.Read function. It reads the C1 packet and
34+
// returns any read error, if there was one. Otherwise, a value of "nil" is
35+
// returned instead.
36+
func (c *ClientAckSequence) Read(r io.Reader) error {
37+
if err := c.C1.Read(r); err != nil {
38+
return err
39+
}
40+
41+
return nil
42+
}
43+
44+
// Write implements the Sequence.Write function. It writes the S1 packet first
45+
// (returning any errors if there is one), and then writes the S2 packet with
46+
// the same data as was sent in the C1 packet (returning any error that was
47+
// encountered).
48+
//
49+
// A successful call to Write constitutes a value of `nil` being returned.
50+
func (c *ClientAckSequence) Write(w io.Writer) error {
51+
if err := c.S1.Write(w); err != nil {
52+
return err
53+
}
54+
55+
s2 := &AckPacket{
56+
Time1: c.C1.Time1,
57+
Payload: c.C1.Payload,
58+
}
59+
60+
if err := s2.Write(w); err != nil {
61+
return err
62+
}
63+
64+
return nil
65+
}
66+
67+
// Nex implements the Sequence.Next function.
68+
func (c *ClientAckSequence) Next() Sequence {
69+
return NewServerAckSequence(c.S1)
70+
}

handshake/client_ack_sequence_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package handshake_test
2+
3+
import (
4+
"bytes"
5+
"crypto/rand"
6+
"io"
7+
"testing"
8+
9+
"github.com/WatchBeam/rtmp/handshake"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestItReadsC1Packet(t *testing.T) {
14+
payload := payload()
15+
16+
buf := bytes.NewBuffer([]byte{})
17+
buf.Write([]byte{0x0, 0x0, 0x0, 0x1}) // Time 1
18+
buf.Write([]byte{0x0, 0x0, 0x0, 0x0}) // Padding
19+
buf.Write(payload[:])
20+
21+
c := handshake.NewClientAckSequence()
22+
err := c.Read(buf)
23+
24+
assert.Nil(t, err)
25+
assert.Equal(t, &handshake.AckPacket{
26+
Time1: 1,
27+
Time2: 0,
28+
Payload: payload,
29+
}, c.C1)
30+
}
31+
32+
func TestItErorrsOnBadC1Packets(t *testing.T) {
33+
buf := bytes.NewBuffer([]byte{
34+
// Empty C1 ~> io.EOF
35+
})
36+
37+
c := handshake.NewClientAckSequence()
38+
err := c.Read(buf)
39+
40+
assert.Equal(t, io.EOF, err)
41+
}
42+
43+
func TestItWritesS1AndMatchingS2(t *testing.T) {
44+
buf := bytes.NewBuffer([]byte{})
45+
46+
c := handshake.NewClientAckSequence()
47+
c.C1 = &handshake.AckPacket{
48+
Payload: payload(),
49+
}
50+
51+
start := 4 + 4 + 1528 + 4 + 4
52+
end := start + 1528
53+
54+
err := c.Write(buf)
55+
56+
assert.Nil(t, err)
57+
assert.Len(t, buf.Bytes(), 2*(4+4+1528))
58+
assert.Equal(t, c.C1.Payload[:], buf.Bytes()[start:end])
59+
}
60+
61+
func payload() [handshake.PayloadLen]byte {
62+
var b [handshake.PayloadLen]byte
63+
64+
rand.Read(b[:])
65+
66+
return b
67+
}

handshake/server_ack_sequence.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package handshake
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
)
8+
9+
var (
10+
// MismatchedChallengeErr is an error which is returned in a situtation
11+
// when the challenge sequence read from the client in C2 differs from
12+
// the challenge sequence in S1.
13+
MismatchedChallengeErr = errors.New("rtmp/handshake: mismatched challenge")
14+
)
15+
16+
// ServerAckSequence is a type implementing the handshake.Sequence interface and
17+
// is repsonseible for reading and verifying the C2 packet written by the client
18+
// against the S1 packet from the server.
19+
type ServerAckSequence struct {
20+
// S1 is the packet which C2 should acknowledge.
21+
S1 *AckPacket
22+
}
23+
24+
var _ Sequence = new(ServerAckSequence)
25+
26+
// NewServerAckSequence returns a new *ServerAckSequence initialized with the
27+
// given S1 packet.
28+
func NewServerAckSequence(S1 *AckPacket) *ServerAckSequence {
29+
return &ServerAckSequence{S1}
30+
}
31+
32+
// Read implements the Handshake.Read method by reading the C2 packet and
33+
// comparing it to the stored S1 packet. If a read error occured while reading
34+
// C2, then it will be returned. If the payloads were not equal, then
35+
// MismatchedChallengeErr will be returned. Otherwise, in the successful case,
36+
// a value of nil will be returned.
37+
func (s *ServerAckSequence) Read(r io.Reader) error {
38+
c2 := new(AckPacket)
39+
if err := c2.Read(r); err != nil {
40+
return err
41+
}
42+
43+
if !bytes.Equal(s.S1.Payload[:], c2.Payload[:]) {
44+
return MismatchedChallengeErr
45+
}
46+
47+
return nil
48+
}
49+
50+
// Write implements the Sequence.Write function. Since there is nothing to
51+
// write, a value of nil is always returned here.
52+
func (s *ServerAckSequence) Write(w io.Writer) error { return nil }
53+
54+
// Next implements the Sequence.Next function. Since there is no next function,
55+
// this function always returns nil.
56+
func (s *ServerAckSequence) Next() Sequence { return nil }

0 commit comments

Comments
 (0)