Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip reflection for most types #369

Merged
merged 1 commit into from
Jun 12, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 16 additions & 160 deletions msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,35 +587,6 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
return lenmsg, err
}
}
case `dns:"txt"`:
if txtTmp == nil {
txtTmp = make([]byte, 256*4+1)
}
off, err = packTxt(fv.Interface().([]string), msg, off, txtTmp)
if err != nil {
return lenmsg, err
}
case `dns:"opt"`: // edns
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).Interface()
b, err := element.(EDNS0).pack()
if err != nil || off+3 > lenmsg {
return lenmsg, &Error{err: "overflow packing opt"}
}
// Option code
binary.BigEndian.PutUint16(msg[off:], element.(EDNS0).Option())
// Length
binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b)))
off += 4
if off+len(b) > lenmsg {
copy(msg[off:], b)
off = lenmsg
continue
}
// Actual data
copy(msg[off:off+len(b)], b)
off += len(b)
}
case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
Expand Down Expand Up @@ -898,108 +869,6 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
servers = append(servers, s)
}
fv.Set(reflect.ValueOf(servers))
case `dns:"txt"`:
if off == lenmsg {
break
}
var txt []string
txt, off, err = unpackTxt(msg, off)
if err != nil {
return lenmsg, err
}
fv.Set(reflect.ValueOf(txt))
case `dns:"opt"`: // edns0
if off == lenmsg {
// This is an EDNS0 (OPT Record) with no rdata
// We can safely return here.
break
}
var edns []EDNS0
Option:
code := uint16(0)
if off+4 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking opt"}
}
code = binary.BigEndian.Uint16(msg[off:])
off += 2
optlen := binary.BigEndian.Uint16(msg[off:])
off1 := off + 2
if off1+int(optlen) > lenmsg {
return lenmsg, &Error{err: "overflow unpacking opt"}
}
switch code {
case EDNS0NSID:
e := new(EDNS0_NSID)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0SUBNET, EDNS0SUBNETDRAFT:
e := new(EDNS0_SUBNET)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
if code == EDNS0SUBNETDRAFT {
e.DraftOption = true
}
case EDNS0COOKIE:
e := new(EDNS0_COOKIE)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0UL:
e := new(EDNS0_UL)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0LLQ:
e := new(EDNS0_LLQ)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0DAU:
e := new(EDNS0_DAU)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0DHU:
e := new(EDNS0_DHU)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
case EDNS0N3U:
e := new(EDNS0_N3U)
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
default:
e := new(EDNS0_LOCAL)
e.Code = code
if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil {
return lenmsg, err
}
edns = append(edns, e)
off = off1 + int(optlen)
}
if off < lenmsg {
goto Option
}
fv.Set(reflect.ValueOf(edns))
case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
Expand Down Expand Up @@ -1295,11 +1164,10 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo
return len(msg), &Error{err: "nil rr"}
}

_, ok := typeToUnpack[rr.Header().Rrtype]
_, ok := blacklist[rr.Header().Rrtype]
switch ok {
case true:
case false:
off1, err = rr.pack(msg, off, compression, compress)
// TODO(miek): revert the logic and make a blacklist for types that still use reflection. Kill typeToUnpack.
default:
off1, err = packStructCompress(rr, msg, off, compression, compress)
}
Expand All @@ -1321,14 +1189,17 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
}
end := off + int(h.Rdlength)

fn, ok := typeToUnpack[h.Rrtype]
_, ok := blacklist[h.Rrtype]
switch ok {
case true:
case false:
// Shortcut reflection.
rr, off, err = fn(h, msg, off)
if fn, known := typeToUnpack[h.Rrtype]; !known {
rr, off, err = unpackRFC3597(h, msg, off)
} else {
rr, off, err = fn(h, msg, off)
}
default:
mk, known := TypeToRR[h.Rrtype]
if !known {
if mk, known := TypeToRR[h.Rrtype]; !known {
rr = new(RFC3597)
} else {
rr = mk()
Expand Down Expand Up @@ -1981,25 +1852,10 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
return dh, off, err
}

// Which types have type specific unpack functions.
var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){
TypeAAAA: unpackAAAA,
TypeA: unpackA,
TypeCNAME: unpackCNAME,
TypeDNAME: unpackDNAME,
TypeL32: unpackL32,
TypeLOC: unpackLOC,
TypeMB: unpackMB,
TypeMD: unpackMD,
TypeMF: unpackMF,
TypeMG: unpackMG,
TypeMR: unpackMR,
TypeMX: unpackMX,
TypeNID: unpackNID,
TypeNS: unpackNS,
TypePTR: unpackPTR,
TypeRP: unpackRP,
TypeSRV: unpackSRV,
TypeHINFO: unpackHINFO,
TypeDNSKEY: unpackDNSKEY,
// Which types do no work reflectionless yet.
var blacklist = map[uint16]bool{
TypeHIP: true,
TypeIPSECKEY: true,
TypeNSEC3: true,
TypeTSIG: true,
}
71 changes: 42 additions & 29 deletions msg_generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,11 @@ import (
"os"
)

// All RR pack and unpack functions should be generated, currently RR that present some
// problems
// * NSEC/NSEC3 - type bitmap
// * TXT/SPF - string slice
// * URI - weird octet thing there
// * NSEC3/TSIG - size hex
// * OPT RR - EDNS0 parsing - needs to some looking at
// All RR pack and unpack functions should be generated, currently RR that present some problems
// * NSEC3 - size hex
// * TSIG - size hex
// * HIP - uses "hex", but is actually size-hex - might drop size-hex?
// * Z
// * NINFO
// * PrivateRR
// * IPSECKEY -

var packageHdr = `
// *** DO NOT MODIFY ***
Expand Down Expand Up @@ -90,10 +84,7 @@ func main() {
fmt.Fprint(b, "// pack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded {
continue
}
st, _ := getTypeStruct(o.Type(), scope)

fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
Expand All @@ -113,19 +104,23 @@ return off, err

if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("off, err = packStringTxt(rr.%s, msg, off)\n")
case `dns:"opt"`:
o("off, err = packDataOpt(rr.%s, msg, off)\n")
case `dns:"nsec"`:
o("off, err = packDataNsec(rr.%s, msg, off)\n")
case `dns:"domain-name"`:
o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}

switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"-"`: // ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
Expand All @@ -142,6 +137,10 @@ return off, err
o("off, err = packStringBase32(rr.%s, msg, off)\n")
case `dns:"base64"`:
o("off, err = packStringBase64(rr.%s, msg, off)\n")
case `dns:"hex"`:
o("off, err = packStringHex(rr.%s, msg, off)\n")
case `dns:"octet"`:
o("off, err = packStringOctet(rr.%s, msg, off)\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
Expand Down Expand Up @@ -169,14 +168,11 @@ return off, err
fmt.Fprint(b, "// unpack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded {
continue
}
st, _ := getTypeStruct(o.Type(), scope)

fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
fmt.Fprint(b, `if noRdata(h) {
return nil, off, nil
return &h, off, nil
}
var err error
rdStart := off
Expand All @@ -196,19 +192,23 @@ return rr, off, err

if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
case `dns:"opt"`:
o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
case `dns:"nsec"`:
o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
case `dns:"domain-name"`:
o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}

switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"-"`: // ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
Expand All @@ -225,6 +225,10 @@ return rr, off, err
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"base64"`:
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"hex"`:
o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"octet"`:
o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
Expand Down Expand Up @@ -253,6 +257,15 @@ return rr, off, nil
}
fmt.Fprintf(b, "return rr, off, err }\n\n")
}
// Generate typeToUnpack map
fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
for _, name := range namedTypes {
if name == "RFC3597" {
continue
}
fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
}
fmt.Fprintln(b, "}\n")

// gofmt
res, err := format.Source(b.Bytes())
Expand Down
Loading