Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,47 +136,91 @@ package enum PeerAddress: Equatable {
// - ipv4:<host>:<port> for ipv4 addresses
// - ipv6:[<host>]:<port> for ipv6 addresses
// - unix:<uds-pathname> for UNIX domain sockets
let addressUTF8View = address.utf8

// First get the first component so that we know what type of address we're dealing with
let addressComponents = address.split(separator: ":", maxSplits: 1)
let firstColonIndex = addressUTF8View.firstIndex(of: UInt8(ascii: ":"))

guard addressComponents.count > 1 else {
guard let firstColonIndex else {
// This is some unexpected/unknown format
return nil
}

let addressType = addressUTF8View[..<firstColonIndex]

var addressWithoutType = addressUTF8View[firstColonIndex...]
addressWithoutType.removeFirst()

// Check what type the transport is...
switch addressComponents[0] {
case "ipv4":
let ipv4AddressComponents = addressComponents[1].split(separator: ":")
if ipv4AddressComponents.count == 2, let port = Int(ipv4AddressComponents[1]) {
self = .ipv4(address: String(ipv4AddressComponents[0]), port: port)
} else {
if addressType.elementsEqual("ipv4".utf8) {
guard let addressColon = addressWithoutType.firstIndex(of: UInt8(ascii: ":")) else {
// This is some unexpected/unknown format
return nil
}

case "ipv6":
if addressComponents[1].first == "[" {
// At this point, we are looking at an address with format: [<address>]:<port>
// We drop the first character ('[') and split by ']:' to keep two components: the address
// and the port.
let ipv6AddressComponents = addressComponents[1].dropFirst().split(separator: "]:")
if ipv6AddressComponents.count == 2, let port = Int(ipv6AddressComponents[1]) {
self = .ipv6(address: String(ipv6AddressComponents[0]), port: port)
} else {
return nil
}
let hostComponent = addressWithoutType[..<addressColon]
var portComponent = addressWithoutType[addressColon...]
portComponent.removeFirst()

if let host = String(hostComponent), let port = Int(ipAddressPortStringBytes: portComponent) {
self = .ipv4(address: host, port: port)
} else {
return nil
}
} else if addressType.elementsEqual("ipv6".utf8) {
guard let lastColonIndex = addressWithoutType.lastIndex(of: UInt8(ascii: ":")) else {
// This is some unexpected/unknown format
return nil
}

case "unix":
// Whatever comes after "unix:" is the <pathname>
self = .unixDomainSocket(path: String(addressComponents[1]))
var hostComponent = addressWithoutType[..<lastColonIndex]
var portComponent = addressWithoutType[lastColonIndex...]
portComponent.removeFirst()

default:
if let firstBracket = hostComponent.popFirst(), let lastBracket = hostComponent.popLast(),
firstBracket == UInt8(ascii: "["), lastBracket == UInt8(ascii: "]"),
let host = String(hostComponent), let port = Int(ipAddressPortStringBytes: portComponent)
{
self = .ipv6(address: host, port: port)
} else {
// This is some unexpected/unknown format
return nil
}
} else if addressType.elementsEqual("unix".utf8) {
// Whatever comes after "unix:" is the <pathname>
self = .unixDomainSocket(path: String(addressWithoutType) ?? "")
} else {
// This is some unexpected/unknown format
return nil
}
}
}

extension Int {
package init?(ipAddressPortStringBytes: some Collection<UInt8>) {
guard (1 ... 5).contains(ipAddressPortStringBytes.count) else {
// Valid IP port values go up to 2^16-1 (65535), which is 5 digits long.
// If the string we get is over 5 characters, we know for sure that this is an invalid port.
// If it's empty, we also know it's invalid as we need at least one digit.
return nil
}

var value = 0
for utf8Char in ipAddressPortStringBytes {
value &*= 10
guard (UInt8(ascii: "0") ... UInt8(ascii: "9")).contains(utf8Char) else {
// non-digit character
return nil
}
value &+= Int(utf8Char - UInt8(ascii: "0"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can do unchecked here because we checked the range of values above

Suggested change
value &+= Int(utf8Char - UInt8(ascii: "0"))
value &+= Int(utf8Char &- UInt8(ascii: "0"))

}

guard value <= Int(UInt16.max) else {
// Valid IP port values go up to 2^16-1.
// If a number greater than this was given, it can't be a valid port.
return nil
}

self = value
}
}
22 changes: 22 additions & 0 deletions Tests/GRPCOTelTracingInterceptorsTests/PeerAddressTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,26 @@ struct PeerAddressTests {
let address = PeerAddress(address)
#expect(address == nil)
}

@Test(
"Int.init(utf8View:)",
arguments: [
("1", 1),
("21", 21),
("321", 321),
("4321", 4321),
("54321", 54321),
("65536", nil), // Invalid: over 65535 IP port limit
("654321", nil), // Invalid: over 5 digits
("abc", nil), // Invalid: no digits
("a123", nil), // Invalid: mixed digits and chars outside the valid ascii range for digits
("123a", nil), // Invalid: mixed digits and chars outside the valid ascii range for digits
("(123", nil), // Invalid: mixed digits and chars outside the valid ascii range for digits
("123(", nil), // Invalid: mixed digits and chars outside the valid ascii range for digits
("", nil), // Invalid: empty string
]
)
func testIntInitFromUTF8View(string: String, expectedInt: Int?) async throws {
#expect(expectedInt == Int(ipAddressPortStringBytes: string.utf8))
}
}