Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 133 additions & 8 deletions alg_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ package alg

import (
"fmt"
"io"
"os"
"syscall"

"golang.org/x/sys/unix"
)

const defaultSocketBufferSize = 64 * 1024

// A conn is the internal connection type for Linux.
type conn struct {
s socket
Expand All @@ -26,7 +30,7 @@ type socket interface {
Sendto(p []byte, flags int, to unix.Sockaddr) error
}

// dial is the entry point for Dial. dial opens an AF_ALG socket
// dial is the entry point for Dial. dial opens an AF_ALG socket
// using system calls.
func dial(typ, name string, config *Config) (*conn, error) {
fd, err := unix.Socket(unix.AF_ALG, unix.SOCK_SEQPACKET, 0)
Expand Down Expand Up @@ -103,20 +107,129 @@ func (h *ihash) Close() error {
return h.s.Close()
}

func (h *ihash) ReadFrom(r io.Reader) (int64, error) {
if f, ok := r.(*os.File); ok {
if w, err, handled := h.sendfile(f, -1); handled {
return w, err
}
if w, err, handled := h.splice(f, -1); handled {
return w, err
}
}
if lr, ok := r.(*io.LimitedReader); ok {
return h.readFromLimitedReader(lr)
}
return genericReadFrom(h, r)
}

func (h *ihash) readFromLimitedReader(lr *io.LimitedReader) (int64, error) {
if f, ok := lr.R.(*os.File); ok {
if w, err, handled := h.sendfile(f, lr.N); handled {
return w, err
}
if w, err, handled := h.splice(f, lr.N); handled {
return w, err
}
}
return genericReadFrom(h, lr)
}

func (h *ihash) splice(f *os.File, remain int64) (written int64, err error, handled bool) {
offset, err := f.Seek(0, io.SeekCurrent)
if err != nil {
return 0, nil, false
}
fi, err := f.Stat()
if err != nil {
return 0, nil, false
}
if remain == -1 {
remain = fi.Size() - offset
}
// mmap must align on a page boundary
// mmap from 0, use data from offset
mmap, err := syscall.Mmap(int(f.Fd()), 0, int(fi.Size()),
syscall.PROT_READ, syscall.MAP_SHARED)
if err != nil {
return 0, nil, false
}
defer syscall.Munmap(mmap)
bytes := mmap[offset : offset+remain]
var (
total = len(bytes)
start = 0
end = defaultSocketBufferSize
)

if end > total {
end = total
}
for {
n, err := h.Write(bytes[start:end])
if err != nil {
return int64(start + n), err, true
}
start += n
if start >= total {
break
}
end += n
if end > total {
end = total
}
}
return remain, nil, true
}

func (h *ihash) sendfile(f *os.File, remain int64) (written int64, err error, handled bool) {
offset, err := f.Seek(0, io.SeekCurrent)
if err != nil {
return 0, nil, false
}
fi, err := f.Stat()
if err != nil {
return 0, nil, false
}
if remain == -1 {
remain = fi.Size() - offset
}
sc, err := f.SyscallConn()
if err != nil {
return 0, nil, false
}
var (
n int
werr error
)
err = sc.Read(func(fd uintptr) bool {
for {
n, werr = syscall.Sendfile(h.s.FD(), int(fd), &offset, int(remain))
written += int64(n)
if werr != nil {
break
}
if int64(n) >= remain {
break
}
remain -= int64(n)
}
return true
})
if err == nil {
err = werr
}
return written, err, true
}

// Write writes data to an AF_ALG socket, but instructs the kernel
// not to finalize the hash.
func (h *ihash) Write(b []byte) (int, error) {
n, err := h.pipes[1].Vmsplice(b, 0)
if err != nil {
return 0, err
return n, err
}

_, err = h.pipes[0].Splice(h.s.FD(), n, unix.SPLICE_F_MOVE|unix.SPLICE_F_MORE)
if err != nil {
return 0, err
}

return len(b), nil
return n, err
}

// Sum reads data from an AF_ALG socket, and appends it to the input
Expand Down Expand Up @@ -187,6 +300,7 @@ type sysPipe struct {
func (p *sysPipe) Splice(out, size, flags int) (int64, error) {
return unix.Splice(p.fd, nil, out, nil, size, flags)
}

func (p *sysPipe) Vmsplice(b []byte, flags int) (int, error) {
iov := unix.Iovec{
Base: &b[0],
Expand All @@ -199,3 +313,14 @@ func (p *sysPipe) Vmsplice(b []byte, flags int) (int, error) {
flags,
)
}

type writerOnly struct {
io.Writer
}

// Fallback implementation of io.ReaderFrom's ReadFrom, when os.File isn't
// applicable.
func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
return io.Copy(writerOnly{w}, r)
}
49 changes: 14 additions & 35 deletions alg_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,41 @@ import (
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"flag"
"fmt"
"hash"
"io"
"log"
"os"
"testing"

"github.com/mdlayher/alg"
)

const MB = (1 << 20)
const MB = 1 << 20

var buf = bytes.Repeat([]byte("a"), 512*MB)

// Flags to specify using either stdlib or AF_ALG transformations.
var (
flagBenchSTD = flag.Bool("bench.std", false, "benchmark only standard library transformations")
flagBenchALG = flag.Bool("bench.alg", false, "benchmark only AF_ALG transformations")
)

func init() {
flag.Parse()
}
//var (
// flagBenchSTD = flag.Bool("bench.std", false, "benchmark only standard library transformations")
// flagBenchALG = flag.Bool("bench.alg", false, "benchmark only AF_ALG transformations")
//)

func TestMD5Equal(t *testing.T) {
const expect = "0829f71740aab1ab98b33eae21dee122"
const expect = "221994040b14294bdf7fbc128e66633c"
withHash(t, "md5", func(algh hash.Hash) {
testHashEqual(t, expect, md5.New(), algh)
})
}

func TestSHA1Equal(t *testing.T) {
const expect = "0631457264ff7f8d5fb1edc2c0211992a67c73e6"
const expect = "2727756cfee3fbfe24bf5650123fd7743d7b3465"
withHash(t, "sha1", func(algh hash.Hash) {
testHashEqual(t, expect, sha1.New(), algh)
})
}

func TestSHA256Equal(t *testing.T) {
const expect = "9f1dcbc35c350d6027f98be0f5c8b43b42ca52b7604459c0c42be3aa88913d47"
const expect = "dd4e6730520932767ec0a9e33fe19c4ce24399d6eba4ff62f13013c9ed30ef87"
withHash(t, "sha256", func(algh hash.Hash) {
testHashEqual(t, expect, sha256.New(), algh)
})
Expand Down Expand Up @@ -89,7 +83,6 @@ func testHashEqual(t *testing.T, expect string, stdh, algh hash.Hash) {

cb := stdh.Sum(nil)
ab := algh.Sum(nil)
log.Printf("%x\n%x", cb, ab)

if want, got := cb, ab; !bytes.Equal(want, got) {
t.Fatalf("unexpected hash sum:\n- std: %x\n- alg: %x", want, got)
Expand Down Expand Up @@ -124,26 +117,12 @@ func benchmarkHashes(b *testing.B, stdh, algh hash.Hash) {
for _, size := range sizes {
for _, page := range pages {
name := fmt.Sprintf("%dMB/%dpages", size, page)
switch {
case *flagBenchSTD && *flagBenchALG:
b.Fatal("cannot specify both '-bench.std' and '-bench.alg'")
case *flagBenchSTD:
b.Run(name, func(b *testing.B) {
benchmarkHash(b, size*MB, page, stdh)
})
case *flagBenchALG:
b.Run(name, func(b *testing.B) {
benchmarkHash(b, size*MB, page, algh)
})
default:
b.Run(name+"/std", func(b *testing.B) {
benchmarkHash(b, size*MB, page, stdh)
})

b.Run(name+"/alg", func(b *testing.B) {
benchmarkHash(b, size*MB, page, algh)
})
}
b.Run(name, func(b *testing.B) {
benchmarkHash(b, size*MB, page, stdh)
})
b.Run(name, func(b *testing.B) {
benchmarkHash(b, size*MB, page, algh)
})
}
}
}
Expand Down
72 changes: 30 additions & 42 deletions alg_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,41 @@
package alg

import (
"bytes"
"reflect"
"encoding/hex"
"testing"

"golang.org/x/sys/unix"
)

func TestLinuxConn_bind(t *testing.T) {
addr := &unix.SockaddrALG{
Type: "hash",
Name: "sha1",
}

s := &testSocket{}
if _, err := bind(s, addr); err != nil {
t.Fatalf("failed to bind: %v", err)
}

if want, got := addr, s.bind; !reflect.DeepEqual(want, got) {
t.Fatalf("unexpected bind address:\n- want: %#v\n- got: %#v",
want, got)
}
}
//func TestLinuxConn_bind(t *testing.T) {
// addr := &unix.SockaddrALG{
// Type: "hash",
// Name: "sha1",
// }
//
// s := &testSocket{}
// if _, err := bind(s, addr); err != nil {
// t.Fatalf("failed to bind: %v", err)
// }
//
// if want, got := addr, s.bind; !reflect.DeepEqual(want, got) {
// t.Fatalf("unexpected bind address:\n- want: %#v\n- got: %#v",
// want, got)
// }
//}

func TestLinuxConnWrite(t *testing.T) {
addr := &unix.SockaddrALG{
Type: "hash",
Name: "sha1",
}

h, s := testLinuxHash(t, addr)
h, _ := testLinuxHash(t, addr)

b := []byte("hello world")
if _, err := h.Write(b); err != nil {
t.Fatalf("failed to write: %v", err)
}

if want, got := b, s.sendto.p; !bytes.Equal(want, got) {
t.Fatalf("unexpected sendto bytes:\n- want: %v\n- got: %v",
want, got)
}

if want, got := unix.MSG_MORE, s.sendto.flags; want != got {
t.Fatalf("unexpected sendto flags:\n- want: %v\n- got: %v",
want, got)
}

if want, got := addr, s.sendto.to; !reflect.DeepEqual(want, got) {
t.Fatalf("unexpected sendto addr:\n- want: %v\n- got: %v",
want, got)
}
}

func TestLinuxConnSum(t *testing.T) {
Expand All @@ -62,20 +46,24 @@ func TestLinuxConnSum(t *testing.T) {
Name: "sha1",
}

h, s := testLinuxHash(t, addr)
s.read = []byte("deadbeef")
h, _ := testLinuxHash(t, addr)

sum := h.Sum([]byte("foo"))
sum := h.Sum(nil)
hex.EncodeToString(sum)

if want, got := []byte("foodeadbeef"), sum; !bytes.Equal(want, got) {
if want, got := "da39a3ee5e6b4b0d3255bfef95601890afd80709", hex.EncodeToString(sum); want != got {
t.Fatalf("unexpected sum bytes:\n- want: %v\n- got: %v",
want, got)
}
}

func testLinuxHash(t *testing.T, addr *unix.SockaddrALG) (Hash, *testSocket) {
s := &testSocket{}
c, err := bind(s, addr)
func testLinuxHash(t *testing.T, addr *unix.SockaddrALG) (Hash, *sysSocket) {
fd, err := unix.Socket(unix.AF_ALG, unix.SOCK_SEQPACKET, 0)
if err != nil {
t.Fatalf("failed to create socket: %v", err)
}

c, err := bind(&sysSocket{fd: fd}, addr)
if err != nil {
t.Fatalf("failed to bind: %v", err)
}
Expand All @@ -86,7 +74,7 @@ func testLinuxHash(t *testing.T, addr *unix.SockaddrALG) (Hash, *testSocket) {
}

// A little gross, but gets the job done.
return hash, hash.(*ihash).s.(*testSocket)
return hash, hash.(*ihash).s.(*sysSocket)
}

var _ socket = &testSocket{}
Expand Down