From 10368afa4067375b817f7114b135640b6a4546b5 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 30 Sep 2024 11:41:02 -0700 Subject: [PATCH] Add `compactfloats` directive (#366) Add `//msgp:compactfloats` file directive, that will store float64 as float32, if it can be done so losslessly. Boring, but correct replacement of https://github.com/tinylib/msgp/pull/365 --- _generated/compactfloats.go | 19 ++++++++ _generated/compactfloats_test.go | 78 ++++++++++++++++++++++++++++++++ gen/decode.go | 5 +- gen/encode.go | 9 ++-- gen/marshal.go | 8 ++-- gen/size.go | 4 +- gen/spec.go | 10 ++-- gen/testgen.go | 4 +- gen/unmarshal.go | 5 +- msgp/write.go | 10 ++++ msgp/write_bytes.go | 10 ++++ msgp/write_bytes_test.go | 69 ++++++++++++++++++++++++++++ parse/directives.go | 15 ++++-- parse/getast.go | 16 ++++--- 14 files changed, 231 insertions(+), 31 deletions(-) create mode 100644 _generated/compactfloats.go create mode 100644 _generated/compactfloats_test.go diff --git a/_generated/compactfloats.go b/_generated/compactfloats.go new file mode 100644 index 00000000..c1e80df0 --- /dev/null +++ b/_generated/compactfloats.go @@ -0,0 +1,19 @@ +package _generated + +//go:generate msgp + +//msgp:compactfloats + +//msgp:ignore F64 +type F64 float64 + +//msgp:replace F64 with:float64 + +type Floats struct { + A float64 + B float32 + Slice []float64 + Map map[string]float64 + F F64 + OE float64 `msg:",omitempty"` +} diff --git a/_generated/compactfloats_test.go b/_generated/compactfloats_test.go new file mode 100644 index 00000000..42376683 --- /dev/null +++ b/_generated/compactfloats_test.go @@ -0,0 +1,78 @@ +package _generated + +import ( + "bytes" + "reflect" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestCompactFloats(t *testing.T) { + // Constant that can be represented in f32 without loss + const f32ok = -1e2 + allF32 := Floats{ + A: f32ok, + B: f32ok, + Slice: []float64{f32ok, f32ok}, + Map: map[string]float64{"a": f32ok}, + F: f32ok, + OE: f32ok, + } + asF32 
:= float32(f32ok) + wantF32 := map[string]any{"A": asF32, "B": asF32, "F": asF32, "Map": map[string]any{"a": asF32}, "OE": asF32, "Slice": []any{asF32, asF32}} + + enc, err := allF32.MarshalMsg(nil) + if err != nil { + t.Error(err) + } + i, _, _ := msgp.ReadIntfBytes(enc) + got := i.(map[string]any) + if !reflect.DeepEqual(got, wantF32) { + t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got) + } + + var buf bytes.Buffer + en := msgp.NewWriter(&buf) + allF32.EncodeMsg(en) + en.Flush() + enc = buf.Bytes() + i, _, _ = msgp.ReadIntfBytes(enc) + got = i.(map[string]any) + if !reflect.DeepEqual(got, wantF32) { + t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got) + } + + const f64ok = -10e64 + allF64 := Floats{ + A: f64ok, + B: f32ok, + Slice: []float64{f64ok, f64ok}, + Map: map[string]float64{"a": f64ok}, + F: f64ok, + OE: f64ok, + } + asF64 := float64(f64ok) + wantF64 := map[string]any{"A": asF64, "B": asF32, "F": asF64, "Map": map[string]any{"a": asF64}, "OE": asF64, "Slice": []any{asF64, asF64}} + + enc, err = allF64.MarshalMsg(nil) + if err != nil { + t.Error(err) + } + i, _, _ = msgp.ReadIntfBytes(enc) + got = i.(map[string]any) + if !reflect.DeepEqual(got, wantF64) { + t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got) + } + + buf.Reset() + en = msgp.NewWriter(&buf) + allF64.EncodeMsg(en) + en.Flush() + enc = buf.Bytes() + i, _, _ = msgp.ReadIntfBytes(enc) + got = i.(map[string]any) + if !reflect.DeepEqual(got, wantF64) { + t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got) + } +} diff --git a/gen/decode.go b/gen/decode.go index ad21173e..cb78c5e3 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -29,7 +29,8 @@ func (d *decodeGen) needsField() { d.hasfield = true } -func (d *decodeGen) Execute(p Elem) error { +func (d *decodeGen) Execute(p Elem, ctx Context) error { + d.ctx = &ctx p = d.applyall(p) if p == nil { return nil @@ -43,8 +44,6 @@ func (d *decodeGen) Execute(p Elem) error { return nil } - d.ctx = 
&Context{} - d.p.comment("DecodeMsg implements msgp.Decodable") d.p.printf("\nfunc (%s %s) DecodeMsg(dc *msgp.Reader) (err error) {", p.Varname(), methodReceiver(p)) diff --git a/gen/encode.go b/gen/encode.go index af83e456..800c4b19 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -28,6 +28,10 @@ func (e *encodeGen) Apply(dirs []string) error { } func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg interface{}) { + if e.ctx.compFloats && typ == "Float64" { + typ = "Float" + } + e.p.printf("\nerr = en.Write%s(%s)", typ, fmt.Sprintf(argfmt, arg)) e.p.wrapErrCheck(e.ctx.ArgsStr()) } @@ -47,7 +51,8 @@ func (e *encodeGen) Fuse(b []byte) { } } -func (e *encodeGen) Execute(p Elem) error { +func (e *encodeGen) Execute(p Elem, ctx Context) error { + e.ctx = &ctx if !e.p.ok() { return e.p.err } @@ -59,8 +64,6 @@ func (e *encodeGen) Execute(p Elem) error { return nil } - e.ctx = &Context{} - e.p.comment("EncodeMsg implements msgp.Encodable") rcv := imutMethodReceiver(p) ogVar := p.Varname() diff --git a/gen/marshal.go b/gen/marshal.go index 5b94ff39..b58fd633 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -27,7 +27,8 @@ func (m *marshalGen) Apply(dirs []string) error { return nil } -func (m *marshalGen) Execute(p Elem) error { +func (m *marshalGen) Execute(p Elem, ctx Context) error { + m.ctx = &ctx if !m.p.ok() { return m.p.err } @@ -39,8 +40,6 @@ func (m *marshalGen) Execute(p Elem) error { return nil } - m.ctx = &Context{} - m.p.comment("MarshalMsg implements msgp.Marshaler") // save the vname before @@ -64,6 +63,9 @@ func (m *marshalGen) Execute(p Elem) error { } func (m *marshalGen) rawAppend(typ string, argfmt string, arg interface{}) { + if m.ctx.compFloats && typ == "Float64" { + typ = "Float" + } m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg)) } diff --git a/gen/size.go b/gen/size.go index 3798532f..d580b104 100644 --- a/gen/size.go +++ b/gen/size.go @@ -69,7 +69,8 @@ func (s *sizeGen) addConstant(sz string) { panic("unknown 
size state") } -func (s *sizeGen) Execute(p Elem) error { +func (s *sizeGen) Execute(p Elem, ctx Context) error { + s.ctx = &ctx if !s.p.ok() { return s.p.err } @@ -81,7 +82,6 @@ func (s *sizeGen) Execute(p Elem) error { return nil } - s.ctx = &Context{} s.ctx.PushString(p.TypeName()) s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message") diff --git a/gen/spec.go b/gen/spec.go index bd57743c..c0bccfd2 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -75,7 +75,8 @@ const ( ) type Printer struct { - gens []generator + gens []generator + CompactFloats bool } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -144,7 +145,7 @@ func (p *Printer) Print(e Elem) error { // collisions between idents created during SetVarname and idents created during Print, // hence the separate prefixes. resetIdent("zb") - err := g.Execute(e) + err := g.Execute(e, Context{compFloats: p.CompactFloats}) resetIdent("za") if err != nil { @@ -171,7 +172,8 @@ func (c contextVar) Arg() string { } type Context struct { - path []contextItem + path []contextItem + compFloats bool } func (c *Context) PushString(s string) { @@ -202,7 +204,7 @@ func (c *Context) ArgsStr() string { type generator interface { Method() Method Add(p TransformPass) - Execute(Elem) error // execute writes the method for the provided object. + Execute(Elem, Context) error // execute writes the method for the provided object. 
} type passes []TransformPass diff --git a/gen/testgen.go b/gen/testgen.go index 0ed81245..d8e85b58 100644 --- a/gen/testgen.go +++ b/gen/testgen.go @@ -26,7 +26,7 @@ type mtestGen struct { w io.Writer } -func (m *mtestGen) Execute(p Elem) error { +func (m *mtestGen) Execute(p Elem, _ Context) error { p = m.applyall(p) if p != nil && IsPrintable(p) { switch p.(type) { @@ -48,7 +48,7 @@ func etest(w io.Writer) *etestGen { return &etestGen{w: w} } -func (e *etestGen) Execute(p Elem) error { +func (e *etestGen) Execute(p Elem, _ Context) error { p = e.applyall(p) if p != nil && IsPrintable(p) { switch p.(type) { diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 75f8a467..05a50597 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -28,8 +28,9 @@ func (u *unmarshalGen) needsField() { u.hasfield = true } -func (u *unmarshalGen) Execute(p Elem) error { +func (u *unmarshalGen) Execute(p Elem, ctx Context) error { u.hasfield = false + u.ctx = &ctx if !u.p.ok() { return u.p.err } @@ -41,8 +42,6 @@ func (u *unmarshalGen) Execute(p Elem) error { return nil } - u.ctx = &Context{} - u.p.comment("UnmarshalMsg implements msgp.Unmarshaler") u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", p.Varname(), methodReceiver(p)) diff --git a/msgp/write.go b/msgp/write.go index ec2f6f52..e1b46a18 100644 --- a/msgp/write.go +++ b/msgp/write.go @@ -346,6 +346,16 @@ func (mw *Writer) WriteNil() error { return mw.push(mnil) } +// WriteFloat writes a float to the writer as either float64 +// or float32 when it represents the exact same value +func (mw *Writer) WriteFloat(f float64) error { + f32 := float32(f) + if float64(f32) == f { + return mw.prefix32(mfloat32, math.Float32bits(f32)) + } + return mw.prefix64(mfloat64, math.Float64bits(f)) +} + // WriteFloat64 writes a float64 to the writer func (mw *Writer) WriteFloat64(f float64) error { return mw.prefix64(mfloat64, math.Float64bits(f)) diff --git a/msgp/write_bytes.go b/msgp/write_bytes.go index 
12606cc2..fb7f7db5 100644 --- a/msgp/write_bytes.go +++ b/msgp/write_bytes.go @@ -60,6 +60,16 @@ func AppendArrayHeader(b []byte, sz uint32) []byte { // AppendNil appends a 'nil' byte to the slice func AppendNil(b []byte) []byte { return append(b, mnil) } +// AppendFloat appends a float to the slice as either float64 +// or float32 when it represents the exact same value +func AppendFloat(b []byte, f float64) []byte { + f32 := float32(f) + if float64(f32) == f { + return AppendFloat32(b, f32) + } + return AppendFloat64(b, f) +} + // AppendFloat64 appends a float64 to the slice func AppendFloat64(b []byte, f float64) []byte { o, n := ensure(b, Float64Size) diff --git a/msgp/write_bytes_test.go b/msgp/write_bytes_test.go index 93e13ff6..e2763279 100644 --- a/msgp/write_bytes_test.go +++ b/msgp/write_bytes_test.go @@ -3,6 +3,7 @@ package msgp import ( "bytes" "math" + "math/rand" "reflect" "strings" "testing" @@ -134,6 +135,74 @@ func TestAppendNil(t *testing.T) { } } +func TestAppendFloat(t *testing.T) { + rng := rand.New(rand.NewSource(0)) + const n = 1e7 + src := make([]float64, n) + for i := range src { + // ~50% full float64, 50% converted from float32. 
+ if rng.Uint32()&1 == 1 { + src[i] = rng.NormFloat64() + } else { + src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32())) + } + } + + var buf bytes.Buffer + en := NewWriter(&buf) + + var bts []byte + for _, f := range src { + en.WriteFloat(f) + bts = AppendFloat(bts, f) + } + en.Flush() + if buf.Len() != len(bts) { + t.Errorf("encoder wrote %d; append wrote %d bytes", buf.Len(), len(bts)) + } + t.Logf("%f bytes/value", float64(buf.Len())/n) + a, b := bts, buf.Bytes() + for i := range a { + if a[i] != b[i] { + t.Errorf("mismatch at byte %d, %d != %d", i, a[i], b[i]) + break + } + } + + for i, want := range src { + var got float64 + var err error + got, a, err = ReadFloat64Bytes(a) + if err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("value #%d: want %v; got %v", i, want, got) + } + } +} + +func BenchmarkAppendFloat(b *testing.B) { + rng := rand.New(rand.NewSource(0)) + const n = 1 << 16 + src := make([]float64, n) + for i := range src { + // ~50% full float64, 50% converted from float32. + if rng.Uint32()&1 == 1 { + src[i] = rng.NormFloat64() + } else { + src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32())) + } + } + buf := make([]byte, 0, 9) + b.SetBytes(8) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + AppendFloat(buf, src[i&(n-1)]) + } +} + func TestAppendFloat64(t *testing.T) { f := float64(3.14159) var buf bytes.Buffer diff --git a/parse/directives.go b/parse/directives.go index 620589d6..2ea1be09 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -22,10 +22,11 @@ type passDirective func(gen.Method, []string, *gen.Printer) error // to add a directive, define a func([]string, *FileSet) error // and then add it to this list. 
var directives = map[string]directive{ - "shim": applyShim, - "replace": replace, - "ignore": ignore, - "tuple": astuple, + "shim": applyShim, + "replace": replace, + "ignore": ignore, + "tuple": astuple, + "compactfloats": compactfloats, } // map of all recognized directives which will be applied @@ -186,3 +187,9 @@ func pointer(text []string, f *FileSet) error { f.pointerRcv = true return nil } + +//msgp:compactfloats +func compactfloats(text []string, f *FileSet) error { + f.CompactFloats = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index ef9782f0..7d6cebe9 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -16,13 +16,14 @@ import ( // A FileSet is the in-memory representation of a // parsed file. type FileSet struct { - Package string // package name - Specs map[string]ast.Expr // type specs in file - Identities map[string]gen.Elem // processed from specs - Directives []string // raw preprocessor directives - Imports []*ast.ImportSpec // imports - tagName string // tag to read field names from - pointerRcv bool // generate with pointer receivers. + Package string // package name + Specs map[string]ast.Expr // type specs in file + Identities map[string]gen.Elem // processed from specs + Directives []string // raw preprocessor directives + Imports []*ast.ImportSpec // imports + CompactFloats bool // Use smaller floats when feasible. + tagName string // tag to read field names from + pointerRcv bool // generate with pointer receivers. } // File parses a file at the relative path @@ -269,6 +270,7 @@ loop: warnf("empty directive: %q\n", d) } } + p.CompactFloats = f.CompactFloats } func (f *FileSet) PrintTo(p *gen.Printer) error {