Skip to content

Commit

Permalink
feat(natpmp): rpc error contain all failed attempt messages
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jan 19, 2024
1 parent c826707 commit 9baa495
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
3 changes: 2 additions & 1 deletion internal/natpmp/portmapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func Test_Client_AddPortMapping(t *testing.T) {
initialConnectionDuration: time.Millisecond,
exchanges: []udpExchange{{close: true}},
err: ErrConnectionTimeout,
errMessage: "executing remote procedure call: connection timeout: after 1ms",
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
},
"add_udp": {
ctx: context.Background(),
Expand Down
61 changes: 56 additions & 5 deletions internal/natpmp/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"net"
"net/netip"
"sort"
"strings"
"time"
)

Expand Down Expand Up @@ -65,9 +67,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
// Note it does not double if the source IP mismatches the gateway IP.
connectionDuration := c.initialConnectionDuration

var totalRetryDuration time.Duration

var retryCount uint
var failedAttempts []string
for retryCount = 0; retryCount < c.maxRetries; retryCount++ {
deadline := time.Now().Add(connectionDuration)
err = connection.SetDeadline(deadline)
Expand All @@ -87,8 +88,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
totalRetryDuration += connectionDuration
connectionDuration *= 2
failedAttempts = append(failedAttempts, netErr.Error())
continue
}
return nil, fmt.Errorf("reading from udp connection: %w", err)
Expand All @@ -98,6 +99,9 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
// Upon receiving a response packet, the client MUST check the source IP
// address, and silently discard the packet if the address is not the
// address of the gateway to which the request was sent.
failedAttempts = append(failedAttempts,
fmt.Sprintf("received response from %s instead of gateway IP %s",
receivedRemoteAddress.IP, gatewayAddress.IP))
continue
}

Expand All @@ -106,8 +110,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
}

if retryCount == c.maxRetries {
return nil, fmt.Errorf("%w: after %s",
ErrConnectionTimeout, totalRetryDuration)
return nil, fmt.Errorf("%w: failed attempts: %s",
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
}

// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
Expand All @@ -121,3 +125,50 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,

return response, nil
}

func dedupFailedAttempts(failedAttempts []string) (errorMessage string) {
type data struct {
message string
indices []int
}
messageToData := make(map[string]data, len(failedAttempts))
for i, message := range failedAttempts {
metadata, ok := messageToData[message]
if !ok {
metadata.message = message
}
metadata.indices = append(metadata.indices, i)
sort.Slice(metadata.indices, func(i, j int) bool {
return metadata.indices[i] < metadata.indices[j]
})
messageToData[message] = metadata
}

// Sort by first index
dataSlice := make([]data, 0, len(messageToData))
for _, metadata := range messageToData {
dataSlice = append(dataSlice, metadata)
}
sort.Slice(dataSlice, func(i, j int) bool {
return dataSlice[i].indices[0] < dataSlice[j].indices[0]
})

dedupedFailedAttempts := make([]string, 0, len(dataSlice))
for _, data := range dataSlice {
newMessage := fmt.Sprintf("%s (%s)", data.message,
indicesToTryString(data.indices))
dedupedFailedAttempts = append(dedupedFailedAttempts, newMessage)
}
return strings.Join(dedupedFailedAttempts, "; ")
}

func indicesToTryString(indices []int) string {
if len(indices) == 1 {
return fmt.Sprintf("try %d", indices[0]+1)
}
tries := make([]string, len(indices))
for i, index := range indices {
tries[i] = fmt.Sprintf("%d", index+1)
}
return fmt.Sprintf("tries %s", strings.Join(tries, ", "))
}
37 changes: 37 additions & 0 deletions internal/natpmp/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,40 @@ func Test_Client_rpc(t *testing.T) {
})
}
}

func Test_dedupFailedAttempts(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
failedAttempts []string
expected string
}{
"empty": {},
"single_attempt": {
failedAttempts: []string{"test"},
expected: "test (try 1)",
},
"multiple_same_attempts": {
failedAttempts: []string{"test", "test", "test"},
expected: "test (tries 1, 2, 3)",
},
"multiple_different_attempts": {
failedAttempts: []string{"test1", "test2", "test3"},
expected: "test1 (try 1); test2 (try 2); test3 (try 3)",
},
"soup_mix": {
failedAttempts: []string{"test1", "test2", "test1", "test3", "test2"},
expected: "test1 (tries 1, 3); test2 (tries 2, 5); test3 (try 4)",
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

actual := dedupFailedAttempts(testCase.failedAttempts)
assert.Equal(t, testCase.expected, actual)
})
}
}

0 comments on commit 9baa495

Please sign in to comment.