Skip to content

Commit

Permalink
Use case insensitive comparison for selection of protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
TomasKorbar committed Oct 18, 2024
1 parent f478774 commit 761c5fd
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
7 changes: 5 additions & 2 deletions dnsconfd/configuration/static_servers_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ def validate(self, value) -> bool:
"specification")
return False
if "protocol" in resolver:
if (resolver["protocol"] != "DoT"
and resolver["protocol"] != "plain"):
if not isinstance(resolver["protocol"], str):
self.lgr.error("protocol has to be a string")
return False
if (resolver["protocol"].lower() != "dot"
and resolver["protocol"].lower() != "plain"):
self.lgr.error("protocol contains invalid value")
return False
if "port" in resolver:
Expand Down
4 changes: 2 additions & 2 deletions dnsconfd/input_modules/dnsconfd_dbus_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def Update(self, servers: list[dict[str, typing.Any]], mode: int) \
protocol = DnsProtocol.PLAIN
if server.get("protocol", None) is not None:
if (not isinstance(server["protocol"], str)
or server["protocol"] not in ["plain", "DoT"]):
or server["protocol"].lower() not in ["plain", "dot"]):
msg = f"{index + 1}. server has unknown protocol " \
f"{server["protocol"]}, only plain or DoT allowed"
self.lgr.error(msg)
return False, msg
if server["protocol"] == "DoT":
if server["protocol"].lower() == "dot":
protocol = DnsProtocol.DNS_OVER_TLS

name = None
Expand Down
39 changes: 32 additions & 7 deletions dnsconfd/network_objects/server_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@ def __init__(self,
""" Object holding information about DNS server
:param address_family: Indicates whether this is IPV4 of IPV6
:type address_family: int
:param address: Address of server
:type address: bytes
:param port: Port the server is listening on, defaults to None
:type port: int, Optional
:param name: Server name indication, when TLS is used, defaults to None
:type name: str, Optional
:param priority: Priority of this server. Higher priority means server
will be used instead of lower priority ones, defaults
to 50
:type priority: int
:param routing_domains: domains whose members will be resolved only by
this or other servers with the same domain
entry
:param search_domains: domains that should be used for host-name
lookup
:param interface: indicating if server can be used only through
interface with this interface index
:param protocol: protocol that should be used for communication with
this server
:param dnssec: indicating whether this server supports dnssec or not
:param networks: networks whose reverse dns records must be resolved
by this server
:param firewall_zone: name of firewall zone that this server should be
associated with
"""
self.address_family = address_family
self.address = bytes(address)
Expand Down Expand Up @@ -145,12 +154,22 @@ def __str__(self) -> str:
"""
return self.to_unbound_string()

def is_family(self, family):
def is_family(self, family: int) -> bool:
""" Get whether this server is of specified IP family
:param family: 4 or 6
:return: True if this object is of the same family as specified,
otherwise False
"""
if self.address_family == socket.AF_INET and family == 4:
return True
return False

def to_dict(self):
""" Get dictionary representing values held by this object
:return: dictionary representation of this object
"""
if self.port:
port = self.port
elif self.protocol == DnsProtocol.DNS_OVER_TLS:
Expand All @@ -168,7 +187,13 @@ def to_dict(self):
"networks": [str(x) for x in self.networks],
"firewall_zone": self.firewall_zone}

def get_rev_zones(self):
def get_rev_zones(self) -> list[str]:
""" Get domains that this server should handle according to its
networks
:return: list of strings containing reverse domains belonging to
networks
"""
zones = []
for net in self.networks:
mem_bits = 8 if net.version == 4 else 4
Expand Down

0 comments on commit 761c5fd

Please sign in to comment.