From 0ea5ad0a4583e1f519b9bcc67cfac381230d9cf2 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 18 Feb 2024 10:27:43 -0800 Subject: [PATCH] The Tudoor fix should not eat valid Truncated exceptions [#1053] (#1054) * The Tudoor fix should not eat valid Truncated exceptions [##1053] * Make logic more readable (cherry picked from commit 2ab3d1628c9ae0545e225522b3b445c3478dc6ad) --- dns/asyncquery.py | 10 ++++++++ dns/query.py | 14 +++++++++++ tests/test_async.py | 60 ++++++++++++++++++++++++++++++++++++++++++++- tests/test_query.py | 44 ++++++++++++++++++++++++++++++++- 4 files changed, 126 insertions(+), 2 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 94cb2413..4d9ab9ae 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -151,6 +151,16 @@ async def receive_udp( ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation, ) + except dns.message.Truncated as e: + # See the comment in query.py for details. + if ( + ignore_errors + and query is not None + and not query.is_response(e.message()) + ): + continue + else: + raise except Exception: if ignore_errors: continue diff --git a/dns/query.py b/dns/query.py index bdd251e7..f0ee9161 100644 --- a/dns/query.py +++ b/dns/query.py @@ -638,6 +638,20 @@ def receive_udp( ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation, ) + except dns.message.Truncated as e: + # If we got Truncated and not FORMERR, we at least got the header with TC + # set, and very likely the question section, so we'll re-raise if the + # message seems to be a response as we need to know when truncation happens. + # We need to check that it seems to be a response as we don't want a random + # injected message with TC set to cause us to bail out. + if ( + ignore_errors + and query is not None + and not query.is_response(e.message()) + ): + continue + else: + raise except Exception: if ignore_errors: continue diff --git a/tests/test_async.py b/tests/test_async.py index ba2078cd..9373548d 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -705,7 +705,11 @@ async def mock_receive( from2, ignore_unexpected=True, ignore_errors=True, + raise_on_truncation=False, + good_r=None, ): + if good_r is None: + good_r = self.good_r s = MockSock(wire1, from1, wire2, from2) (r, when, _) = await dns.asyncquery.receive_udp( s, @@ -713,9 +717,10 @@ async def mock_receive( time.time() + 2, ignore_unexpected=ignore_unexpected, ignore_errors=ignore_errors, + raise_on_truncation=raise_on_truncation, query=self.q, ) - self.assertEqual(r, self.good_r) + self.assertEqual(r, good_r) def test_good_mock(self): async def run(): @@ -802,6 +807,59 @@ async def run(): self.async_run(run) + def test_good_wire_with_truncation_flag_and_no_truncation_raise(self): + async def run(): + tc_r = dns.message.make_response(self.q) + tc_r.flags |= dns.flags.TC + tc_r_wire = tc_r.to_wire() + await self.mock_receive( + tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r + ) + + self.async_run(run) + + def test_good_wire_with_truncation_flag_and_truncation_raise(self): + async def agood(): + tc_r = dns.message.make_response(self.q) + tc_r.flags |= dns.flags.TC + tc_r_wire = tc_r.to_wire() + await self.mock_receive( + tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True + ) + + def good(): + self.async_run(agood) + + self.assertRaises(dns.message.Truncated, good) + + def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self): + async def run(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r.flags |= dns.flags.TC + bad_r_wire = bad_r.to_wire() + await self.mock_receive( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + self.async_run(run) + + def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self): + async def run(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r.flags |= dns.flags.TC + bad_r_wire = bad_r.to_wire() + await self.mock_receive( + bad_r_wire, + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + raise_on_truncation=True, + ) + + self.async_run(run) + def test_bad_wire_not_ignored(self): bad_r = dns.message.make_response(self.q) bad_r.id += 1 diff --git a/tests/test_query.py b/tests/test_query.py index 1039a14e..62007e85 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -29,6 +29,7 @@ have_ssl = False import dns.exception +import dns.flags import dns.inet import dns.message import dns.name @@ -706,7 +707,11 @@ def mock_receive( from2, ignore_unexpected=True, ignore_errors=True, + raise_on_truncation=False, + good_r=None, ): + if good_r is None: + good_r = self.good_r s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: with mock_udp_recv(wire1, from1, wire2, from2): @@ -716,9 +721,10 @@ def mock_receive( time.time() + 2, ignore_unexpected=ignore_unexpected, ignore_errors=ignore_errors, + raise_on_truncation=raise_on_truncation, query=self.q, ) - self.assertEqual(r, self.good_r) + self.assertEqual(r, good_r) finally: s.close() @@ -787,6 +793,42 @@ def test_bad_wire(self): bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) ) + def test_good_wire_with_truncation_flag_and_no_truncation_raise(self): + tc_r = dns.message.make_response(self.q) + tc_r.flags |= dns.flags.TC + tc_r_wire = tc_r.to_wire() + self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r) + + def test_good_wire_with_truncation_flag_and_truncation_raise(self): + def good(): + tc_r = dns.message.make_response(self.q) + tc_r.flags |= dns.flags.TC + tc_r_wire = tc_r.to_wire() + self.mock_receive( + tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True + ) + + self.assertRaises(dns.message.Truncated, good) + + def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r.flags |= dns.flags.TC + bad_r_wire = bad_r.to_wire() + self.mock_receive( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r.flags |= dns.flags.TC + bad_r_wire = bad_r.to_wire() + self.mock_receive( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53), + raise_on_truncation=True + ) + def test_bad_wire_not_ignored(self): bad_r = dns.message.make_response(self.q) bad_r.id += 1