From b8d1dfe2e08b4971c0487f228003ae1421768f20 Mon Sep 17 00:00:00 2001 From: aliel Date: Mon, 9 Oct 2023 11:34:28 +0200 Subject: [PATCH] speedup dns detection using authoritative ns server --- src/aleph/sdk/domain.py | 61 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/domain.py b/src/aleph/sdk/domain.py index b014f353..ff6a4999 100644 --- a/src/aleph/sdk/domain.py +++ b/src/aleph/sdk/domain.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from ipaddress import IPv6Address +from ipaddress import IPv4Address, IPv6Address from typing import Dict, Iterable, List, NewType, Optional, Union from urllib.parse import urlparse @@ -22,6 +22,8 @@ class TargetType(str, Enum): def hostname_from_url(url: Union[HttpUrl, str]) -> Hostname: + """Extract FQDN from url""" + parsed = urlparse(url) if all([parsed.scheme, parsed.netloc]) is True: url = parsed.netloc @@ -39,6 +41,60 @@ class DomainValidator: def __init__(self): self.resolver = aiodns.DNSResolver(servers=settings.DNS_RESOLVERS) + async def get_ns_servers(self, hostname: Hostname): + """Get ns servers of a domain""" + dns_servers = settings.DNS_RESOLVERS + fqdn = hostname + + stop = False + while stop == False: + """**Detect and get authoritative NS server of subdomains if delegated**""" + try: + entries = await self.resolver.query(fqdn, "NS") + servers = [] + for entry in entries: + servers += await self.get_ipv6_addresses(entry.host) + servers += await self.get_ipv4_addresses(entry.host) + + dns_servers = servers + stop = True + except aiodns.error.DNSError: + sub_domains = fqdn.split(".") + if len(sub_domains) > 2: + fqdn = '.'.join(sub_domains[1:]) + continue + + if len(sub_domains) == 2: + stop = True + + return dns_servers + + async def get_resolver_for(self, hostname: Hostname): + dns_servers = await self.get_ns_servers(hostname) + return aiodns.DNSResolver(servers=dns_servers) + + async def get_target_type(self, fqdn: Hostname) -> Optional[TargetType]: + domain_validator = DomainValidator() + resolver = await domain_validator.get_resolver_for(fqdn) + try: + entry = await resolver.query(fqdn, "CNAME") + cname = getattr(entry, "cname") + if cname == settings.DNS_IPFS_DOMAIN: + return TargetType.IPFS + elif cname == settings.DNS_PROGRAM_DOMAIN: + return TargetType.PROGRAM + elif cname == settings.DNS_INSTANCE_DOMAIN: + return TargetType.INSTANCE + + return None + except aiodns.error.DNSError: + return None + + async def get_ipv4_addresses(self, hostname: Hostname) -> List[IPv4Address]: + """Returns all IPv4 addresses for a domain""" + entries: Iterable = await self.resolver.query(hostname, "A") or [] + return [entry.host for entry in entries] + async def get_ipv6_addresses(self, hostname: Hostname) -> List[IPv6Address]: """Returns all IPv6 addresses for a domain""" entries: Iterable = await self.resolver.query(hostname, "AAAA") or [] @@ -109,7 +165,8 @@ async def check_domain( record_value = dns_rule["dns"]["value"] try: - entries = await self.resolver.query(record_name, record_type.upper()) + resolver = await self.get_resolver_for(hostname) + entries = await resolver.query(record_name, record_type.upper()) except aiodns.error.DNSError: """Continue checks""" entries = None