diff --git a/checkdmarc.py b/checkdmarc.py index 85133a6..0178226 100644 --- a/checkdmarc.py +++ b/checkdmarc.py @@ -828,7 +828,7 @@ def _get_txt_records(domain, nameservers=None, resolver=None, timeout=2.0): return records -def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0): +def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, raise_for_unrelated_records=True): """ Queries DNS for a DMARC record @@ -846,12 +846,13 @@ def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0): dmarc_record = None dmarc_record_count = 0 unrelated_records = [] + dmarc_prefix = "v=DMARC1" try: records = _query_dns(target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout) for record in records: - if record.startswith("v=DMARC1"): + if record.startswith(dmarc_prefix): dmarc_record_count += 1 elif record.strip().startswith("v=DMARC1"): raise DMARCRecordStartsWithWhitespace( @@ -867,12 +868,13 @@ def _query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0): "Multiple DMARC policy records are not permitted - " "https://tools.ietf.org/html/rfc7489#section-6.6.3") if len(unrelated_records) > 0: - raise UnrelatedTXTRecordFoundAtDMARC( - "Unrelated TXT records were discovered. These should be " - "removed, as some receivers may not expect to find " - "unrelated TXT records " - "at {0}\n\n{1}".format(target, "\n\n".join(unrelated_records))) - dmarc_record = records[0] + if raise_for_unrelated_records: + raise UnrelatedTXTRecordFoundAtDMARC( + "Unrelated TXT records were discovered. These should be " + "removed, as some receivers may not expect to find " + "unrelated TXT records " + "at {0}\n\n{1}".format(target, "\n\n".join(unrelated_records))) + dmarc_record = [record for record in records if record.startswith(dmarc_prefix)][0] except dns.resolver.NoAnswer: try: @@ -969,7 +971,7 @@ def _query_bmi_record(domain, selector="default", nameservers=None, return bimi_record -def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0): +def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0, raise_for_unrelated_records=True): """ Queries DNS for a DMARC record @@ -998,7 +1000,7 @@ def query_dmarc_record(domain, nameservers=None, resolver=None, timeout=2.0): base_domain = get_base_domain(domain) location = domain.lower() record = _query_dmarc_record(domain, nameservers=nameservers, - resolver=resolver, timeout=timeout) + resolver=resolver, timeout=timeout, raise_for_unrelated_records=raise_for_unrelated_records) try: root_records = _query_dns(domain.lower(), "TXT", nameservers=nameservers, resolver=resolver,