Skip to content

Commit

Permalink
fix(firewall): delete chain rules by line number
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Aug 14, 2024
1 parent 09c47c7 commit 51c661f
Show file tree
Hide file tree
Showing 14 changed files with 1,100 additions and 64 deletions.
84 changes: 84 additions & 0 deletions internal/firewall/delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package firewall

import (
"context"
"fmt"
"os/exec"
"strings"
)

func isDeleteInstruction(instruction string) bool {
fields := strings.Fields(instruction)
for i := 0; i < len(fields); i += 2 {
switch fields[i] {
case "-D", "--delete": //nolint:goconst
return true
}
}
return false
}

func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string,
runner Runner, logger Logger) (err error) {
targetRule, err := parseIptablesInstruction(instruction)
if err != nil {
return fmt.Errorf("parsing iptables command: %w", err)
}

lineNumber, err := findLineNumber(ctx, iptablesBinary,
targetRule, runner, logger)
if err != nil {
return fmt.Errorf("finding iptables chain rule line number: %w", err)
} else if lineNumber == 0 {
logger.Debug("rule matching \"" + instruction + "\" not found")
return nil
}
logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d",
instruction, lineNumber))

cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table,
"-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return err
}

return nil
}

// findLineNumber finds the line number of an iptables rule.
// It returns 0 if the rule is not found.
func findLineNumber(ctx context.Context, iptablesBinary string,
instruction iptablesInstruction, runner Runner, logger Logger) (
lineNumber uint16, err error) {
listFlags := []string{"-t", instruction.table, "-L", instruction.chain,
"--line-numbers", "-n", "-v"}
cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204
logger.Debug(cmd.String())
output, err := runner.Run(cmd)
if err != nil {
err = fmt.Errorf("command failed: %q: %w", cmd, err)
if output != "" {
err = fmt.Errorf("%w: %s", err, output)
}
return 0, err
}

chain, err := parseChain(output)
if err != nil {
return 0, fmt.Errorf("parsing chain list: %w", err)
}

for _, rule := range chain.rules {
if instruction.equalToRule(instruction.table, chain.name, rule) {
return rule.lineNumber, nil
}
}

return 0, nil
}
148 changes: 148 additions & 0 deletions internal/firewall/delete_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package firewall

import (
"context"
"errors"
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)

func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam
return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$",
"^--line-numbers$", "^-n$", "^-v$")
}

func Test_deleteIPTablesRule(t *testing.T) {
t.Parallel()

const iptablesBinary = "/sbin/iptables"
errTest := errors.New("test error")

testCases := map[string]struct {
instruction string
makeRunner func(ctrl *gomock.Controller) *MockRunner
makeLogger func(ctrl *gomock.Controller) *MockLogger
errWrapped error
errMessage string
}{
"invalid_instruction": {
instruction: "invalid",
errWrapped: ErrIptablesCommandMalformed,
errMessage: "parsing iptables command: iptables command is malformed: " +
"fields count 1 is not even: \"invalid\"",
},
"list_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().
Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
return logger
},
errWrapped: errTest,
errMessage: `finding iptables chain rule line number: command failed: ` +
`"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`,
},
"rule_not_found": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)
num pkts bytes target prot opt in out source destination
1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll
nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found")
return logger
},
},
"rule_found_delete_error": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("details", errTest)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
errWrapped: errTest,
errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details",
},
"rule_found_delete_success": {
instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678",
makeRunner: func(ctrl *gomock.Controller) *MockRunner {
runner := NewMockRunner(ctrl)
runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")).
Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+
"num pkts bytes target prot opt in out source destination \n"+
"1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll
"2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll
nil)
runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$",
"^-D$", "^PREROUTING$", "^2$")).Return("", nil)
return runner
},
makeLogger: func(ctrl *gomock.Controller) *MockLogger {
logger := NewMockLogger(ctrl)
logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v")
logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " +
"-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2")
logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2")
return logger
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)

ctx := context.Background()
instruction := testCase.instruction
var runner *MockRunner
if testCase.makeRunner != nil {
runner = testCase.makeRunner(ctrl)
}
var logger *MockLogger
if testCase.makeLogger != nil {
logger = testCase.makeLogger(ctrl)
}

err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger)

assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
13 changes: 13 additions & 0 deletions internal/firewall/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package firewall

import "github.com/qdm12/golibs/command"

type Runner interface {
Run(cmd command.ExecCmd) (output string, err error)
}

type Logger interface {
Debug(s string)
Info(s string)
Error(s string)
}
8 changes: 6 additions & 2 deletions internal/firewall/ip6tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
c.ip6tablesMutex.Lock() // only one ip6tables command at once
defer c.ip6tablesMutex.Unlock()

c.logger.Debug(c.ip6Tables + " " + instruction)
if isDeleteInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ip6Tables, instruction,
c.runner, c.logger)
}

flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ip6Tables, instruction, output, err)
Expand All @@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid")

func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
switch policy {
case "ACCEPT", "DROP":
case "ACCEPT", "DROP": //nolint:goconst
default:
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
}
Expand Down
12 changes: 8 additions & 4 deletions internal/firewall/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
c.iptablesMutex.Lock() // only one iptables command at once
defer c.iptablesMutex.Unlock()

c.logger.Debug(c.ipTables + " " + instruction)
if isDeleteInstruction(instruction) {
return deleteIPTablesRule(ctx, c.ipTables, instruction,
c.runner, c.logger)
}

flags := strings.Fields(instruction)
cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204
c.logger.Debug(cmd.String())
if output, err := c.runner.Run(cmd); err != nil {
return fmt.Errorf("command failed: \"%s %s\": %s: %w",
c.ipTables, instruction, output, err)
Expand Down Expand Up @@ -109,8 +113,8 @@ func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, r

func (c *Config) acceptInputToSubnet(ctx context.Context, intf string,
destination netip.Prefix, remove bool) error {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag := "-i " + intf //nolint:goconst
if intf == "*" { // all interfaces
interfaceFlag = ""
}

Expand Down Expand Up @@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
defaultInterface string, connection models.Connection, remove bool) error {
protocol := connection.Protocol
if protocol == "tcp-client" {
protocol = "tcp"
protocol = "tcp" //nolint:goconst
}
instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), connection.IP, defaultInterface, protocol,
Expand Down
Loading

0 comments on commit 51c661f

Please sign in to comment.