diff --git a/shadowstream/packet.go b/shadowstream/packet.go index 3f6cc445..b266f581 100644 --- a/shadowstream/packet.go +++ b/shadowstream/packet.go @@ -46,20 +46,19 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) { type packetConn struct { net.PacketConn Cipher - wbuf []byte - rbuf []byte + buf []byte sync.Mutex // write lock } // NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption. func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn { - return &packetConn{PacketConn: c, Cipher: ciph, wbuf: make([]byte, 64*1024), rbuf: make([]byte, 64*1024)} + return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, 64*1024)} } func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { c.Lock() defer c.Unlock() - buf, err := Pack(c.wbuf, b, c.Cipher) + buf, err := Pack(c.buf, b, c.Cipher) if err != nil { return 0, err } @@ -68,10 +67,14 @@ func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { } func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, addr, err := c.PacketConn.ReadFrom(c.rbuf) + n, addr, err := c.PacketConn.ReadFrom(b) if err != nil { return n, addr, err } - b, err = Unpack(b, c.rbuf[:n], c.Cipher) + bb, err := Unpack(b[c.IVSize():], b[:n], c.Cipher) + if err != nil { + return n, addr, err + } + copy(b, bb) return len(b), addr, err }