From e5dd9020f4e59b48120cfd6edf7121ea0e97ec31 Mon Sep 17 00:00:00 2001 From: garmr Date: Fri, 13 Oct 2023 16:27:52 -0700 Subject: [PATCH] fix newstrategy test in geneva_test.go, fix linter errors --- actions/actions.go | 8 +- actions/duplicate_action.go | 7 +- actions/fragment_packet.go | 22 +++-- actions/tamper_action.go | 163 ++++++++++++++++++++-------------- actions/tamper_action_test.go | 2 + common/packet.go | 8 +- geneva_test.go | 4 +- 7 files changed, 134 insertions(+), 80 deletions(-) diff --git a/actions/actions.go b/actions/actions.go index 4e9edcb..01bd3d4 100644 --- a/actions/actions.go +++ b/actions/actions.go @@ -4,17 +4,21 @@ package actions import ( + "errors" "fmt" + "github.com/google/gopacket" + "github.com/getlantern/geneva/internal" "github.com/getlantern/geneva/internal/scanner" "github.com/getlantern/geneva/triggers" - "github.com/google/gopacket" // gopacket best practice says import this, too. _ "github.com/google/gopacket/layers" ) +var ErrInvalidAction = errors.New("invalid action") + // ActionTree represents a Geneva (trigger, action) pair. // // Technically, Geneva uses the term "action tree" to refer to the tree of actions in the tuple @@ -122,5 +126,5 @@ func ParseAction(s *scanner.Scanner) (Action, error) { return DefaultSendAction, nil } - return nil, fmt.Errorf("invalid action at %d", s.Pos()) + return nil, fmt.Errorf("%w at %d", ErrInvalidAction, s.Pos()) } diff --git a/actions/duplicate_action.go b/actions/duplicate_action.go index 490f7ed..7512144 100644 --- a/actions/duplicate_action.go +++ b/actions/duplicate_action.go @@ -4,9 +4,10 @@ import ( "errors" "fmt" + "github.com/google/gopacket" + "github.com/getlantern/geneva/internal" "github.com/getlantern/geneva/internal/scanner" - "github.com/google/gopacket" ) // DuplicateAction is a Geneva action that duplicates a packet and applies separate action trees to @@ -114,6 +115,10 @@ func ParseDuplicateAction(s *scanner.Scanner) (Action, error) { } if action.Left, err = ParseAction(s); err != nil { + if !errors.Is(err, ErrInvalidAction) { + return nil, err + } + if c, err2 := s.Peek(); err2 == nil && c == ',' { action.Left = &SendAction{} } else { diff --git a/actions/fragment_packet.go b/actions/fragment_packet.go index f8cb34f..db2d429 100644 --- a/actions/fragment_packet.go +++ b/actions/fragment_packet.go @@ -145,7 +145,9 @@ func fragmentTCPSegment(packet gopacket.Packet, fragSize int) ([]gopacket.Packet ipHdrLen := uint16(ipv4Buf[0]&0x0f) * 4 first := gopacket.NewPacket(buf, packet.Layers()[0].LayerType(), gopacket.NoCopy) - updateChecksums(first) + if err := updateChecksums(first); err != nil { + return nil, err + } // create the second fragment. f2Len := headersLen + tcpPayloadLen - fragSize @@ -166,7 +168,9 @@ func fragmentTCPSegment(packet gopacket.Packet, fragSize int) ([]gopacket.Packet binary.BigEndian.PutUint32(tcp[4:], seqNum) second := gopacket.NewPacket(buf, packet.Layers()[0].LayerType(), gopacket.NoCopy) - updateChecksums(second) + if err := updateChecksums(second); err != nil { + return nil, err + } return []gopacket.Packet{first, second}, nil } @@ -260,14 +264,16 @@ func FragmentIPPacket(packet gopacket.Packet, fragSize int) ([]gopacket.Packet, return []gopacket.Packet{first, second}, nil } -func updateChecksums(packet gopacket.Packet) { - if ipv4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4); ipv4 != nil { +func updateChecksums(packet gopacket.Packet) error { + if ipv4, _ := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4); ipv4 != nil { common.UpdateIPv4Checksum(ipv4) } - if tcp := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); tcp != nil { - common.UpdateTCPChecksum(tcp) + if tcp, _ := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); tcp != nil { + return common.UpdateTCPChecksum(tcp) } + + return nil } // VerifyIPv4Checksum verifies whether an IPv4 header's checksum field is correct. @@ -375,6 +381,10 @@ func ParseFragmentAction(s *scanner.Scanner) (Action, error) { } if action.FirstFragmentAction, err = ParseAction(s); err != nil { + if !errors.Is(err, ErrInvalidAction) { + return nil, err + } + if c, err2 := s.Peek(); err2 == nil && c == ',' { action.FirstFragmentAction = &SendAction{} } else { diff --git a/actions/tamper_action.go b/actions/tamper_action.go index 8944256..0b29df2 100644 --- a/actions/tamper_action.go +++ b/actions/tamper_action.go @@ -1,3 +1,6 @@ +// Package actions describes the actions that can be applied to a given packet. +// +//nolint:depguard,dupword package actions import ( @@ -24,6 +27,13 @@ const ( TamperCorrupt ) +//nolint:revive +var ( + ErrInvalidTamperMode = errors.New("invalid tamper mode") + ErrInvalidTamperRule = errors.New("invalid tamper rule") + ErrUDPNotSupported = errors.New("UDP tamper action not currently supported") +) + // TamperMode describes the way that the "tamper" action can manipulate a packet. type TamperMode int @@ -80,19 +90,19 @@ func (a *TamperAction) String() string { // If the string is malformed, an error will be returned instead. func ParseTamperAction(s *scanner.Scanner) (Action, error) { if _, err := s.Expect("tamper{"); err != nil { - return nil, fmt.Errorf("invalid tamper rule: %w", err) + return nil, fmt.Errorf("%s: %w", ErrInvalidTamperRule, err) } str, err := s.Until('}') if err != nil { - return nil, fmt.Errorf("invalid tamper rule: %w", err) + return nil, fmt.Errorf("%s: %w", ErrInvalidTamperRule, err) } _, _ = s.Pop() fields := strings.Split(str, ":") if len(fields) < 3 || len(fields) > 4 { - return nil, fmt.Errorf("invalid fields for tamper rule: %s", str) + return nil, fmt.Errorf("%s: invalid field, %w", ErrInvalidTamperRule, err) } var ( @@ -111,7 +121,8 @@ func ParseTamperAction(s *scanner.Scanner) (Action, error) { mode = TamperCorrupt default: return nil, fmt.Errorf( - "invalid tamper mode: %q must be either 'replace' or 'corrupt'", + "%w: %q must be either 'replace' or 'corrupt'", + ErrInvalidTamperMode, fields[2], ) } @@ -123,37 +134,46 @@ func ParseTamperAction(s *scanner.Scanner) (Action, error) { NewValue: newValue, } - if _, err := s.Expect("("); err == nil { - if tamperAction.Action, err = ParseAction(s); err != nil { - if c, err2 := s.Peek(); err2 == nil && c == ')' { - tamperAction.Action = &SendAction{} - } else { - return nil, fmt.Errorf("invalid action for tamper rule: %w", err) - } + if _, err = s.Expect("("); err != nil { + tamperAction.Action = &SendAction{} + return newTamperAction(tamperAction) + } + + if tamperAction.Action, err = ParseAction(s); err != nil { + if !errors.Is(err, ErrInvalidAction) { + return nil, err } - if _, err = s.Expect(","); err == nil { - if !s.FindToken(")", true) { - return nil, fmt.Errorf("tamper rules can only have one action") - } + if c, err2 := s.Peek(); err2 == nil && c == ')' { + tamperAction.Action = &SendAction{} + } else { + return nil, fmt.Errorf("%s: invalid action, %w", ErrInvalidTamperRule, err) } + } - if _, err := s.Expect(")"); err != nil { - return nil, fmt.Errorf("unexpected token in tamper rule: %w", err) + if _, err = s.Expect(","); err == nil { + if !s.FindToken(")", true) { + return nil, fmt.Errorf("%w: only one action is allowed", ErrInvalidTamperRule) } - } else { - tamperAction.Action = &SendAction{} } - switch proto { + if _, err = s.Expect(")"); err != nil { + return nil, fmt.Errorf("%s: unexpected token: %w", ErrInvalidTamperRule, err) + } + + return newTamperAction(tamperAction) +} + +func newTamperAction(ta TamperAction) (Action, error) { + switch ta.Proto { case "IP": - return NewIPv4TamperAction(tamperAction) + return NewIPv4TamperAction(ta) case "TCP": - return NewTCPTamperAction(tamperAction) + return NewTCPTamperAction(ta) case "UDP": - return nil, fmt.Errorf("UDP tamper action not currently supported") + return nil, ErrUDPNotSupported default: - return nil, fmt.Errorf("invalid tamper rule: %q is not a recognized protocol", proto) + return nil, fmt.Errorf("%w: %q is not a recognized protocol", ErrInvalidTamperRule, ta.Proto) } } @@ -164,6 +184,7 @@ func ParseTamperAction(s *scanner.Scanner) (Action, error) { // TCPField is a TCP field that can be modified by a TCPTamperAction. type TCPField uint8 +//nolint:revive const ( // supported TCP options. The other options are apparently obsolete and not used. TCPOptionEol = layers.TCPOptionKindEndList @@ -183,15 +204,16 @@ const ( // putting fields after options so that we can use the gopacket.TCPOptionKind constants for options. // this lets us use the same map for both fields and options and also directly compare // tcpTamperAction.field == TCPOption when iterating over tcpPacket.Options. - TCPFieldSrcPort = 9 - TCPFieldDstPort = 10 - TCPFieldSeq = 11 - TCPFieldAck = 12 - TCPFieldDataOff = 13 - TCPFieldFlags = 15 - TCPFieldWindow = 16 - TCPFieldUrgent = 17 - TCPLoad = 18 + TCPFieldSrcPort = 9 + TCPFieldDstPort = 10 + TCPFieldSeq = 11 + TCPFieldAck = 12 + TCPFieldDataOff = 13 + TCPFieldFlags = 15 + TCPFieldWindow = 16 + TCPFieldUrgent = 17 + TCPFieldChecksum = 18 + TCPLoad = 20 // TCP flag string representations for tamper rules. TCPFlagFin = "f" @@ -217,6 +239,7 @@ var ( "flags": TCPFieldFlags, "window": TCPFieldWindow, "urgent": TCPFieldUrgent, + "chksum": TCPFieldChecksum, "options-eol": TCPOptionEol, "options-nop": TCPOptionNop, "options-mss": TCPOptionMss, @@ -259,12 +282,12 @@ type TCPTamperAction struct { func NewTCPTamperAction(ta TamperAction) (*TCPTamperAction, error) { field, ok := tcpFields[ta.Field] if !ok { - return nil, fmt.Errorf("invalid tamper rule: %q is not a recognized TCP field", ta.Field) + return nil, fmt.Errorf("%w: %q is not a recognized TCP field", ErrInvalidTamperRule, ta.Field) } switch ta.Mode { case TamperCorrupt: - r := rand.New(rand.NewSource(time.Now().UnixNano())) + r := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec return &TCPTamperAction{ TamperAction: ta, @@ -294,7 +317,12 @@ func NewTCPTamperAction(ta TamperAction) (*TCPTamperAction, error) { default: val, err := strconv.ParseUint(ta.NewValue, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid tamper rule: %q is not a valid value for field %q", ta.NewValue, ta.Field) + return nil, fmt.Errorf( + "%w: %q is not a valid value for field %q", + ErrInvalidTamperRule, + ta.NewValue, + ta.Field, + ) } gen.vUint = uint32(val) @@ -307,27 +335,30 @@ func NewTCPTamperAction(ta TamperAction) (*TCPTamperAction, error) { }, nil } - return nil, fmt.Errorf("invalid tamper rule: %q is not a valid tamper mode for TCP", ta.Mode) + return nil, fmt.Errorf("%w: %q is not a valid tamper mode for TCP", ErrInvalidTamperRule, ta.Mode) } // Apply applies the tamper action to the given packet. func (a *TCPTamperAction) Apply(packet gopacket.Packet) ([]gopacket.Packet, error) { tcp := packet.Layer(layers.LayerTypeTCP).(*layers.TCP) if tcp == nil { - return nil, errors.New("packet does not have a TCP layer") + return nil, errors.New("packet does not have a TCP layer") //nolint:goerr113 } tamperTCP(tcp, a.field, a.valueGen) // if tampering with TCP options, we need to update the data offset and checksum if strings.HasPrefix(a.Field, "options") { - updateTCPDataOffAndChksum(tcp) + if err := updateTCPDataOffAndChksum(tcp); err != nil { + return nil, fmt.Errorf("failed to update checksum: %w", err) + } + if ip := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4); ip != nil { updateIPv4LengthAndChksum(ip) } } - return a.Action.Apply(packet) + return a.Action.Apply(packet) //nolint:wrapcheck } // tamperTCP modifies the given TCP field using the given value generator. @@ -347,6 +378,8 @@ func tamperTCP(tcp *layers.TCP, field TCPField, valueGen tamperValueGen) { tcp.Window = uint16(valueGen.uint(16)) case TCPFieldUrgent: tcp.Urgent = uint16(valueGen.uint(16)) + case TCPFieldChecksum: + tcp.Checksum = uint16(valueGen.uint(16)) case TCPFieldFlags: setTCPFlags(tcp, uint16(valueGen.uint(16))) case TCPLoad: @@ -354,6 +387,7 @@ func tamperTCP(tcp *layers.TCP, field TCPField, valueGen tamperValueGen) { default: // find option in TCP header var opt *layers.TCPOption + for i, o := range tcp.Options { if field == TCPField(o.OptionType) { opt = &tcp.Options[i] @@ -383,7 +417,7 @@ func tamperTCP(tcp *layers.TCP, field TCPField, valueGen tamperValueGen) { // instead of the underlying []byte directly. SerializeTo doesn't write the changes to the raw packet // so we have to copy the formatted bytes back into the packet header. sb := gopacket.NewSerializeBuffer() - tcp.SerializeTo(sb, gopacket.SerializeOptions{}) + tcp.SerializeTo(sb, gopacket.SerializeOptions{}) //nolint:errcheck,gosec tcp.Contents = make([]byte, len(sb.Bytes())) copy(tcp.Contents, sb.Bytes()) } @@ -393,7 +427,7 @@ func tcpFlagsToUint32(flags string) uint32 { flags = strings.ToLower(flags) var f uint32 - for _, c := range flags { + for _, c := range flags { //nolint:wsl switch c { case 'f': // FIN f |= 0x0001 @@ -415,6 +449,7 @@ func tcpFlagsToUint32(flags string) uint32 { f |= 0x0100 } } + return f } @@ -433,14 +468,14 @@ func setTCPFlags(tcp *layers.TCP, flags uint16) { // updateTCPDataOffAndChksum updates the TCP data offset and checksum fields on the TCP struct // and in the raw packet bytes. -func updateTCPDataOffAndChksum(tcp *layers.TCP) { +func updateTCPDataOffAndChksum(tcp *layers.TCP) error { // update data offset headerLen := len(tcp.Contents) tcp.DataOffset = uint8(headerLen / 4) tcp.Contents[12] = tcp.DataOffset << 4 // update checksum. - common.UpdateTCPChecksum(tcp) + return common.UpdateTCPChecksum(tcp) //nolint:wrapcheck } // @@ -450,8 +485,9 @@ func updateTCPDataOffAndChksum(tcp *layers.TCP) { // IPv4Field is an IPv4 field that can be modified by an IPv4TamperAction. type IPv4Field uint8 +//nolint:revive const ( - // supported IPv4 fields + // supported IPv4 fields. IPv4FieldSrcIP = iota IPv4FieldDstIP IPv4FieldVersion @@ -473,7 +509,7 @@ var ipv4Fields = map[string]IPv4Field{ "verion": IPv4FieldVersion, "ihl": IPv4FieldIHL, "tos": IPv4FieldTOS, - "length": IPv4FieldLength, + "len": IPv4FieldLength, "id": IPv4FieldID, // // I don't know what the flags will look like in a tamper rule @@ -501,12 +537,12 @@ type IPv4TamperAction struct { func NewIPv4TamperAction(ta TamperAction) (*IPv4TamperAction, error) { field, ok := ipv4Fields[ta.Field] if !ok { - return nil, fmt.Errorf("invalid tamper rule: %q is not a recognized IPv4 field", ta.Field) + return nil, fmt.Errorf("%w: %q is not a recognized IPv4 field", ErrInvalidTamperRule, ta.Field) } switch ta.Mode { case TamperCorrupt: - r := rand.New(rand.NewSource(time.Now().UnixNano())) + r := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec return &IPv4TamperAction{ TamperAction: ta, @@ -521,11 +557,11 @@ func NewIPv4TamperAction(ta TamperAction) (*IPv4TamperAction, error) { // parse IP address from NewValue and convert to []byte ip := net.ParseIP(ta.NewValue) if ip == nil { - return nil, fmt.Errorf("invalid tamper rule: %q is not a valid IPv4 address", ta.NewValue) + return nil, fmt.Errorf("%w: %q is not a valid IPv4 address", ErrInvalidTamperRule, ta.NewValue) } if ip.To4() == nil { - return nil, fmt.Errorf("invalid tamper rule: IPv6 is not supported") + return nil, fmt.Errorf("%w: IPv6 is not supported", ErrInvalidTamperRule) } gen.vBytes = ip @@ -535,7 +571,7 @@ func NewIPv4TamperAction(ta TamperAction) (*IPv4TamperAction, error) { // parse uint from NewValue val, err := strconv.ParseUint(ta.NewValue, 10, 32) if err != nil { - return nil, fmt.Errorf("invalid tamper rule: %q is not a valid value for field %q", ta.NewValue, ta.Field) + return nil, fmt.Errorf("%w: %q is not a valid value for field %q", ErrInvalidTamperRule, ta.NewValue, ta.Field) } gen.vUint = uint32(val) @@ -548,24 +584,24 @@ func NewIPv4TamperAction(ta TamperAction) (*IPv4TamperAction, error) { }, nil } - return nil, fmt.Errorf("invalid tamper rule: %q is not a valid tamper mode for IPv4", ta.Mode) + return nil, fmt.Errorf("%w: %q is not a valid tamper mode for IPv4", ErrInvalidTamperRule, ta.Mode) } // Apply applies the tamper action to the given packet. func (a *IPv4TamperAction) Apply(packet gopacket.Packet) ([]gopacket.Packet, error) { ip := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if ip == nil { - return nil, fmt.Errorf("packet does not have a IPv4 layer") + return nil, errors.New("packet does not have a IPv4 layer") //nolint:goerr113 } tamperIPv4(ip, a.field, a.valueGen) common.UpdateIPv4Checksum(ip) - return a.Action.Apply(packet) + return a.Action.Apply(packet) //nolint:wrapcheck } // tamperIPv4 modifies the given IP field using the given value generator. -func tamperIPv4(ip *layers.IPv4, field IPv4Field, valueGen tamperValueGen) error { +func tamperIPv4(ip *layers.IPv4, field IPv4Field, valueGen tamperValueGen) { switch field { case IPv4FieldSrcIP: ip.SrcIP = valueGen.bytes(4) @@ -580,7 +616,7 @@ func tamperIPv4(ip *layers.IPv4, field IPv4Field, valueGen tamperValueGen) error case IPv4FieldLength: ip.Length = uint16(valueGen.uint(16)) case IPv4FieldFlags: - // TODO: maybe implement this? + // not implemented yet. see comment above. case IPv4FieldFragOffset: ip.FragOffset = uint16(valueGen.uint(16)) case IPv4FieldTTL: @@ -596,25 +632,18 @@ func tamperIPv4(ip *layers.IPv4, field IPv4Field, valueGen tamperValueGen) error // let gopacket handle converting modified packet into []byte again, it's just easier // again copy the bytes back into the packet header sb := gopacket.NewSerializeBuffer() - ip.SerializeTo(sb, gopacket.SerializeOptions{}) + ip.SerializeTo(sb, gopacket.SerializeOptions{}) //nolint:errcheck,gosec ip.Contents = make([]byte, len(sb.Bytes())) copy(ip.Contents, sb.Bytes()) - - return nil } -// updateIPv4LengthAndChksum updates the IPv4 length +// updateIPv4LengthAndChksum updates the IPv4 length. func updateIPv4LengthAndChksum(ip *layers.IPv4) { length := len(ip.Contents) + len(ip.Payload) ip.Length = uint16(length) binary.BigEndian.PutUint16(ip.Contents[2:4], ip.Length) } -// -// UDP Tamper Action -// TODO: implement UDP tamper actions -// - // tamperValueGen is a value generator for tamper actions. type tamperValueGen interface { uint(bitSize int) uint32 @@ -638,6 +667,7 @@ func (g *tamperReplaceGen) bytes(n int) []byte { if n == 0 { return []byte{} } + return append([]byte{}, g.vBytes...) } @@ -654,6 +684,8 @@ func (g *tamperCorruptGen) uint(bitSize int) uint32 { // bytes returns a random byte slice of length n if n <= 20, otherwise it returns a // a random byte slice of random length up to n. +// +//nolint:gosec func (g *tamperCorruptGen) bytes(n int) []byte { if n > 20 { n = g.r.Intn(n) @@ -661,5 +693,6 @@ func (g *tamperCorruptGen) bytes(n int) []byte { b := make([]byte, n) g.r.Read(b) + return b } diff --git a/actions/tamper_action_test.go b/actions/tamper_action_test.go index 4690d85..fabbb41 100644 --- a/actions/tamper_action_test.go +++ b/actions/tamper_action_test.go @@ -1,3 +1,4 @@ +//nolint:depguard,testpackage package actions import ( @@ -93,6 +94,7 @@ func TestTamperTCP(t *testing.T) { field TCPField valueGen tamperValueGen } + tests := []struct { name string args args diff --git a/common/packet.go b/common/packet.go index 0539baf..b2127d9 100644 --- a/common/packet.go +++ b/common/packet.go @@ -2,6 +2,7 @@ package common import ( "encoding/binary" + "fmt" "github.com/google/gopacket/layers" ) @@ -14,7 +15,7 @@ func UpdateTCPChecksum(tcp *layers.TCP) error { chksum, err := tcp.ComputeChecksum() if err != nil { - return err + return fmt.Errorf("failed to update TCP checksum: %w", err) } tcp.Checksum = chksum @@ -24,14 +25,13 @@ func UpdateTCPChecksum(tcp *layers.TCP) error { } // UpdateIPv4Checksum updates the IPv4 checksum field and the raw bytes for a gopacket IPv4 layer. -func UpdateIPv4Checksum(ip *layers.IPv4) error { +func UpdateIPv4Checksum(ip *layers.IPv4) { chksum := CalculateIPv4Checksum(ip.Contents) ip.Checksum = chksum binary.BigEndian.PutUint16(ip.Contents[10:12], chksum) - - return nil } +// CalculateIPv4Checksum calculates the IPv4 checksum for the given bytes. // copied from gopacket/layers/ip4.go because they didn't export one. for whatever some reason.. func CalculateIPv4Checksum(bytes []byte) uint16 { buf := make([]byte, len(bytes)) diff --git a/geneva_test.go b/geneva_test.go index ee4b3d8..3d7f070 100644 --- a/geneva_test.go +++ b/geneva_test.go @@ -36,10 +36,10 @@ var examples = []string{ func TestNewStrategy(t *testing.T) { t.Parallel() - for _, s := range examples { + for i, s := range examples { _, err := geneva.NewStrategy(s) if err != nil { - t.Errorf("failed to parse strategy: %v", err) + t.Errorf("failed to parse strategy %d %q: %v", i, s, err) } } }