Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"asyncbackend",
"asyncquery",
"asyncresolver",
"asyncserver",
"btree",
"btreezone",
"dnssec",
Expand Down
9 changes: 9 additions & 0 deletions dns/_asyncbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ async def make_socket(
):
raise NotImplementedError

async def serve(
self,
client_connected_cb,
af,
socktype,
addr,
):
raise NotImplementedError

def datagram_connection_required(self):
return False

Expand Down
26 changes: 26 additions & 0 deletions dns/_asyncio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,32 @@ async def make_socket(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover

async def serve(
self,
client_connected_cb,
af,
socktype,
addr,
):
if socktype == socket.SOCK_DGRAM:
sock = await self.make_socket(af, socket.SOCK_DGRAM, 0, addr)
await client_connected_cb(sock)
elif socktype == socket.SOCK_STREAM:
async def handle_tcp(r, w):
sock_tcp = _StreamSocket(af, r, w)
await client_connected_cb(sock_tcp)
hostname, port = addr
server = await asyncio.start_server(
handle_tcp,
host=hostname,
port=port,
family=af,
)
await server.serve_forever()
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover

async def sleep(self, interval):
await asyncio.sleep(interval)

Expand Down
24 changes: 24 additions & 0 deletions dns/_trio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,30 @@ async def make_socket(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover

async def serve(
self,
client_connected_cb,
af,
socktype,
addr,
):
if socktype == socket.SOCK_DGRAM:
sock = await self.make_socket(af, socket.SOCK_DGRAM, 0, addr)
await client_connected_cb(sock)
elif socktype == socket.SOCK_STREAM:
async def handle_tcp(stream):
sock_tcp = StreamSocket(af, stream)
await client_connected_cb(sock_tcp)
hostname, port = addr
await trio.serve_tcp(
handle_tcp,
host=hostname,
port=port,
)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover

async def sleep(self, interval):
await trio.sleep(interval)

Expand Down
108 changes: 108 additions & 0 deletions dns/asyncserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import socket
from typing import Awaitable, Callable, Never

import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.rcode
import dns.tsig


def _rcode_from_exception(e: Exception) -> dns.rcode.Rcode:
"""Get rcode for exception"""
if isinstance(e, dns.exception.FormError):
return dns.rcode.FORMERR
elif isinstance(e, dns.exception.SyntaxError):
return dns.rcode.SERVFAIL
elif isinstance(e, dns.exception.UnexpectedEnd):
return dns.rcode.BADTRUNC
elif isinstance(e, dns.exception.TooBig):
return dns.rcode.BADTRUNC
elif isinstance(e, dns.exception.Timeout):
return dns.rcode.SERVFAIL
elif isinstance(e, dns.exception.UnsupportedAlgorithm):
return dns.rcode.BADALG
elif isinstance(e, dns.exception.AlgorithmKeyMismatch):
return dns.rcode.BADALG
elif isinstance(e, dns.exception.ValidationFailure):
return dns.rcode.SERVFAIL
elif isinstance(e, dns.exception.DeniedByPolicy):
return dns.rcode.REFUSED
elif isinstance(e, NotImplementedError):
return dns.rcode.NOTIMP
return dns.rcode.SERVFAIL


async def udp_serve(
cb: Callable[[dns.message.Message, str], Awaitable[dns.message.Message]],
host: str,
port: int = 53,
keyring: dict[dns.name.Name, dns.tsig.Key] | None = None,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
ignore_errors: bool = False,
backend: dns.asyncbackend.Backend | None = None,
) -> None:
async def handle_udp(sock: dns.asyncbackend.DatagramSocket):
while True:
(m, _, from_address) = await dns.asyncquery.receive_udp(
sock,
one_rr_per_rrset=one_rr_per_rrset,
keyring=keyring,
ignore_trailing=ignore_trailing,
ignore_errors=ignore_errors,
)
try:
r = await cb(m, from_address)
except Exception as e:
r = dns.message.make_response(m)
r.set_rcode(_rcode_from_exception(e))
wire = r.to_wire()
await dns.asyncquery.send_udp(sock, wire, from_address)

if not backend:
backend = dns.asyncbackend.get_default_backend()
af = dns.inet.af_for_address(host)
addr = (host, port)
await backend.serve(handle_udp, af, socket.SOCK_DGRAM, addr)


async def tcp_serve(
cb: Callable[[dns.message.Message, str], Awaitable[dns.message.Message]],
host: str,
port: int = 53,
keyring: dict[dns.name.Name, dns.tsig.Key] | None = None,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
ignore_errors: bool = False,
backend: dns.asyncbackend.Backend | None = None,
) -> None:
async def handle_tcp(sock: dns.asyncbackend.StreamSocket):
peer_address = await sock.getpeername()
while True:
try:
(m, _) = await dns.asyncquery.receive_tcp(
sock,
one_rr_per_rrset=one_rr_per_rrset,
keyring=keyring,
ignore_trailing=ignore_trailing,
ignore_errors=ignore_errors,
)
try:
r = await cb(m, peer_address)
except Exception as e:
r = dns.message.make_response(m)
r.set_rcode(_rcode_from_exception(e))
wire = r.to_wire()
await dns.asyncquery.send_tcp(sock, wire)
except EOFError:
break

if not backend:
backend = dns.asyncbackend.get_default_backend()
af = dns.inet.af_for_address(host)
addr = (host, port)
await backend.serve(handle_tcp, af, socket.SOCK_STREAM, addr)
115 changes: 40 additions & 75 deletions examples/ddns_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import asyncio
import logging
import struct
import typing

import dns.asyncserver
import dns.exception
import dns.message
import dns.name
Expand All @@ -30,45 +30,27 @@
}


def response(msg, code=dns.rcode.SERVFAIL):
response = dns.message.make_response(msg)
response.set_rcode(code)
return response.to_wire()


async def handle_nsupdate(data, addr):
async def handle_nsupdate(msg: dns.message.Message, addr):
cli = addr[0]
msg = dns.message.from_wire(data, keyring=KEYRING)
try:
if msg.opcode() != dns.opcode.UPDATE:
raise NotImplementedError("Opcode %s not implemented" % dns.opcode.to_text(msg.opcode()))
update_msg = typing.cast(dns.update.UpdateMessage, msg)
zone = update_msg.zone[0].name
if not msg.had_tsig or msg.keyname not in TEST_ZONES[zone]:
raise dns.exception.DeniedByPolicy(f"Key {msg.keyname} not allowed for zone {zone}")
for r in update_msg.update:
if r.deleting:
if r.deleting == dns.rdataclass.ANY and r.rdtype == dns.rdatatype.ANY:
logging.info("%s: delete_all_rrsets %s" % (cli, r))
elif r.deleting == dns.rdataclass.ANY:
logging.info("%s: delete_rrset %s" % (cli, r))
elif r.deleting == dns.rdataclass.NONE:
logging.info("%s: delete_from_rrset %s" % (cli, r))
else:
logging.info("%s: add_to_rrset %s" % (cli, r))
except dns.exception.FormError:
logging.exception("Rejected %s: Error parsing message" % cli)
return response(msg, code=dns.rcode.FORMERR)
except dns.exception.DeniedByPolicy:
logging.exception("Rejected %s: Validation error" % cli)
return response(msg, code=dns.rcode.REFUSED)
except NotImplementedError:
logging.exception("Rejected %s: Not implemented error" % cli)
return response(msg, code=dns.rcode.NOTIMP)
except:
logging.exception("Rejected %s: Internal error" % cli)
return response(msg, code=dns.rcode.SERVFAIL)
return response(msg, code=dns.rcode.NOERROR)
if msg.opcode() != dns.opcode.UPDATE:
raise NotImplementedError("Opcode %s not implemented" % dns.opcode.to_text(msg.opcode()))
update_msg = typing.cast(dns.update.UpdateMessage, msg)
zone = update_msg.zone[0].name
if not msg.had_tsig or msg.keyname not in TEST_ZONES[zone]:
raise dns.exception.ValidationFailure(f"Key {msg.keyname} not allowed for zone {zone}")
for r in update_msg.update:
if r.deleting:
if r.deleting == dns.rdataclass.ANY and r.rdtype == dns.rdatatype.ANY:
logging.info("%s: delete_all_rrsets %s" % (cli, r))
elif r.deleting == dns.rdataclass.ANY:
logging.info("%s: delete_rrset %s" % (cli, r))
elif r.deleting == dns.rdataclass.NONE:
logging.info("%s: delete_from_rrset %s" % (cli, r))
else:
logging.info("%s: add_to_rrset %s" % (cli, r))
response = dns.message.make_response(msg)
response.set_rcode(dns.rcode.NOERROR)
return response


async def main():
Expand All @@ -77,43 +59,26 @@ async def main():

logging.basicConfig(level=logging.INFO)
logging.info(f"Starting servers at {hostname}:{port}")
loop = asyncio.get_event_loop()

# Start UDP server
class DatagramProtocol(asyncio.DatagramProtocol):
def connection_made(self, transport):
self.transport = transport

def datagram_received(self, data, addr):
asyncio.ensure_future(self.handle(data, addr))

async def handle(self, data, addr):
result = await handle_nsupdate(data, addr)
self.transport.sendto(result, addr)

transport, _protocol = await loop.create_datagram_endpoint(lambda: DatagramProtocol(), local_addr=(hostname, port))

# Start TCP server
class StreamReaderProtocol(asyncio.StreamReaderProtocol):
def __init__(self):
super().__init__(asyncio.StreamReader(), self.handle_tcp)

async def handle_tcp(self, reader, writer):
addr = writer.transport.get_extra_info("peername")
while True:
try:
(size,) = struct.unpack("!H", await reader.readexactly(2))
except asyncio.IncompleteReadError:
break
data = await reader.readexactly(size)

result = await handle_nsupdate(data, addr)
bsize = struct.pack("!H", len(result))
writer.write(bsize)
writer.write(result)

server = await loop.create_server(lambda: StreamReaderProtocol(), hostname, port)
await server.serve_forever()
async with asyncio.TaskGroup() as tg:
tg.create_task(
dns.asyncserver.udp_serve(
handle_nsupdate,
hostname,
port,
KEYRING,
one_rr_per_rrset=True,
),
)
tg.create_task(
dns.asyncserver.tcp_serve(
handle_nsupdate,
hostname,
port,
KEYRING,
one_rr_per_rrset=True,
),
)


asyncio.run(main())