diff --git a/stack/socket_tcp.go b/stack/socket_tcp.go index e0e10ad..54a7cc4 100644 --- a/stack/socket_tcp.go +++ b/stack/socket_tcp.go @@ -149,6 +149,11 @@ func (t *TCPSocket) handleRecv(response []byte, pkt *TCPPacket) (n int, err erro // if segIncoming.SEQ != t.scb.RecvNext() { // return 0, ErrDroppedPacket // SCB does not admit out-of-order packets. // } + + t.scb.Recv(segIncoming) + // if err != nil { + // return 0, err + // } if segIncoming.Flags.HasAny(seqs.FlagPSH) { if len(payload) != int(segIncoming.DATALEN) { return 0, errors.New("segment data length does not match payload length") @@ -163,11 +168,9 @@ func (t *TCPSocket) handleRecv(response []byte, pkt *TCPPacket) (n int, err erro return 0, err } } - err = t.scb.Recv(segIncoming) - if err != nil { - return 0, err + if t.tx.Buffered() > 0 { + return t.handleUser(response, pkt) // Yield to handleUser. } - segOut, ok := t.scb.PendingSegment(0) if !ok { return 0, nil // No pending control segment. Yield to handleUser. @@ -294,7 +297,7 @@ func (r *ring) Read(b []byte) (int, error) { if r.end >= r.off { // start off end len(buf) // | sfree | used | efree | - n := copy(b, r.buf[r.off:]) + n := copy(b, r.buf[r.off:r.end]) r.off += n r.onReadEnd() return n, nil @@ -304,7 +307,7 @@ func (r *ring) Read(b []byte) (int, error) { n := copy(b, r.buf[r.off:]) r.off += n if n < len(b) { - n2 := copy(b[n:], r.buf) + n2 := copy(b[n:], r.buf[:r.end]) r.off = n2 n += n2 } diff --git a/stack/stack_test.go b/stack/stack_test.go index 9e6c61c..44f2442 100644 --- a/stack/stack_test.go +++ b/stack/stack_test.go @@ -36,7 +36,7 @@ func TestStackEstablish(t *testing.T) { } } -func TestStackSendReceive(t *testing.T) { +func TestStackSendReceive_simplex(t *testing.T) { client, server := createTCPClientServerPair(t) // 3 way handshake needs2 exchanges to complete. @@ -47,23 +47,42 @@ func TestStackSendReceive(t *testing.T) { // Send data from client to server. const data = "hello world" - err := client.Send([]byte(data)) - if err != nil { - t.Fatal(err) + socketSendString(client, data) + txStacks(t, 1, client.PortStack(), server.PortStack()) + if client.State() != seqs.StateEstablished || server.State() != seqs.StateEstablished { + t.Fatal("not established") + } + got := socketReadAllString(server) + if got != data { + t.Error("got", got, "want", data) } +} - txStacks(t, 1, client.PortStack(), server.PortStack()) +func TestStackSendReceive_duplex(t *testing.T) { + client, server := createTCPClientServerPair(t) + cstack, sstack := client.PortStack(), server.PortStack() + // 3 way handshake needs2 exchanges to complete. + txStacks(t, 2, cstack, sstack) if client.State() != seqs.StateEstablished || server.State() != seqs.StateEstablished { t.Fatal("not established") } - var buf [len(data)]byte - n, err := server.Recv(buf[:]) - if err != nil { - t.Fatal(err) + // Send data from client to server. + const data = "hello world" + socketSendString(client, data) + socketSendString(server, data) + tx, bytes := txStacks(t, 2, cstack, sstack) + if client.State() != seqs.StateEstablished || server.State() != seqs.StateEstablished { + t.Fatal("not established") } - if string(buf[:n]) != data { - t.Error("got", string(buf[:n]), "want", data) + t.Logf("tx=%d bytes=%d", tx, bytes) + clientstr := socketReadAllString(client) + serverstr := socketReadAllString(server) + if clientstr != data { + t.Error("got", clientstr, "want", data) + } + if serverstr != data { + t.Error("got", serverstr, "want", data) } } @@ -173,3 +192,23 @@ func createPortStacks(t *testing.T, n int) (stacks []*stack.PortStack) { } return stacks } + +func socketReadAllString(s *stack.TCPSocket) string { + var str strings.Builder + var buf [1024]byte + for { + n, err := s.Recv(buf[:]) + str.Write(buf[:n]) + if n == 0 || err != nil { + break + } + } + return str.String() +} + +func socketSendString(s *stack.TCPSocket, str string) { + err := s.Send([]byte(str)) + if err != nil { + panic(err) + } +}