Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pydnsbl/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from .checker import DNSBLChecker, DNSBLDomainChecker, DNSBLIpChecker

__all__ = (
"DNSBLChecker",
"DNSBLDomainChecker",
"DNSBLIpChecker"
)
164 changes: 107 additions & 57 deletions pydnsbl/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,28 @@
print(result.categories)
print(result.detected_by)
"""

import abc
import asyncio
import idna
import ipaddress
import re
import sys
import threading
import warnings
from typing import Iterable, NamedTuple, Optional, Union

import aiodns
import idna
from pycares import ares_query_a_result

from .providers import Provider, BASE_PROVIDERS, BASE_DOMAIN_PROVIDERS
from .providers import BASE_DOMAIN_PROVIDERS, BASE_PROVIDERS, Provider

if sys.platform == 'win32' and sys.version_info >= (3, 8):
if sys.platform == "win32" and sys.version_info >= (3, 8):
# fixes https://github.com/dmippolitov/pydnsbl/issues/12
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# Some users may want to use winloop which is a windows version of uvloop
if asyncio.DefaultEventLoopPolicy.__name__ == "WindowsProactorEventLoopPolicy":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


class DNSBLResult:
"""
Expand All @@ -37,18 +43,23 @@ class DNSBLResult:
this dnsbls. dict: {'dnsbl_list_name': list(categories_from_this_dnsbl)}
* categories - set of dnsbl categories from all providers (subset of DNSBL_CATEGORIES)
"""
def __init__(self, addr=None, results=None):

def __init__(
self,
addr: Optional[list[Union[str, bytes]]] = None,
results: Optional[list["DNSBLResponse"]] = None,
):
self.addr = addr
self._results = results
self.blacklisted = False
self.providers = []
self.failed_providers = []
self.detected_by = {}
self.providers: list[Provider] = []
self.failed_providers: list[str] = []
self.detected_by: dict = {}
self.categories = set()
self.process_results()

def process_results(self):
""" Process results by providers """
def process_results(self) -> None:
"""Process results by providers"""
for result in self._results:
provider = result.provider
self.providers.append(provider)
Expand All @@ -63,35 +74,47 @@ def process_results(self):
self.categories = self.categories.union(provider_categories)
self.detected_by[provider.host] = list(provider_categories)

def __repr__(self):
blacklisted = '[BLACKLISTED]' if self.blacklisted else ''
return "<DNSBLResult: %s %s (%d/%d)>" % (self.addr, blacklisted, len(self.detected_by),
len(self.providers))
def __repr__(self) -> str:
blacklisted = "[BLACKLISTED]" if self.blacklisted else ""
return "<DNSBLResult: %s %s (%d/%d)>" % (
self.addr,
blacklisted,
len(self.detected_by),
len(self.providers),
)

class DNSBLResponse:

class DNSBLResponse(NamedTuple):
"""
DNSBL Response object
"""
def __init__(self, addr=None, provider=None, response=None, error=None):
self.addr = addr
self.provider = provider
self.response = response
self.error = error

addr: Optional[Union[str, bytes]] = None
provider: Optional[Provider] = None
response: Optional[list[ares_query_a_result]] = None
error: Optional[aiodns.error.DNSError] = None


class BaseDNSBLChecker(abc.ABC):
""" BASE Checker for DNSBL lists
Arguments:
* providers(list) - list of providers (Provider instance or str)
* timeout(int) - timeout of dns requests will be passed to resolver
* tries(int) - retry times
"""BASE Checker for DNSBL lists
Arguments:
* providers(list) - list of providers (Provider instance or str)
* timeout(int) - timeout of dns requests will be passed to resolver
* tries(int) - retry times
"""

def __init__(self, providers=BASE_PROVIDERS, timeout=5,
tries=2, concurrency=200, loop=None):
def __init__(
self,
providers: list[Provider] = BASE_PROVIDERS,
timeout: float = 5,
tries: int = 2,
concurrency: int = 200,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
self.providers = []
for provider in providers:
if not isinstance(provider, Provider):
raise ValueError('providers should contain only Provider instances')
raise ValueError("providers should contain only Provider instances")
self.providers.append(provider)
if not loop:
if threading.current_thread() == threading.main_thread():
Expand All @@ -101,11 +124,20 @@ def __init__(self, providers=BASE_PROVIDERS, timeout=5,
asyncio.set_event_loop(self._loop)
else:
self._loop = loop
self._resolver = aiodns.DNSResolver(timeout=timeout, tries=tries, loop=self._loop)
self._resolver = aiodns.DNSResolver(
timeout=timeout, tries=tries, loop=self._loop
)
self._semaphore = asyncio.Semaphore(concurrency)

async def __aenter__(self):
return self

async def dnsbl_request(self, request, provider):
async def __aexit__(self, *args):
return await self._resolver.close()

async def dnsbl_request(
self, request: Union[bytes, str], provider: str
) -> DNSBLResponse:
"""
Make lookup to dnsbl provider for ip
Parameters:
Expand All @@ -123,32 +155,34 @@ async def dnsbl_request(self, request, provider):
dnsbl_query = "%s.%s" % (self.prepare_query(request), provider.host)
try:
async with self._semaphore:
response = await self._resolver.query(dnsbl_query, 'A')
response = await self._resolver.query(dnsbl_query, "A")
except aiodns.error.DNSError as exc:
if exc.args[0] != 4: # 4: domain name not found:
if exc.args[0] != 4: # 4: domain name not found:
error = exc

return DNSBLResponse(addr=request, provider=provider, response=response, error=error)
return DNSBLResponse(
addr=request, provider=provider, response=response, error=error
)

@abc.abstractmethod
def prepare_query(self, request):
def prepare_query(self, request: Union[bytes, str]) -> str:
"""
Prepare query to dnsbl
"""
return NotImplemented

async def check_async(self, request):
tasks = []
async def check_async(self, request: Union[bytes, str]):
tasks: list[asyncio.Task[DNSBLResponse]] = []
for provider in self.providers:
tasks.append(self.dnsbl_request(request, provider))
results = await asyncio.gather(*tasks)
return DNSBLResult(addr=request, results=results)

def check(self, request):
def check(self, request: Union[bytes, str]) -> DNSBLResult:
return self._loop.run_until_complete(self.check_async(request))

def bulk_check(self, requests):
tasks = []
def bulk_check(self, requests: Iterable[Union[bytes, str]]) -> list[DNSBLResult]:
tasks: list[asyncio.Task[DNSBLResult]] = []
for request in requests:
tasks.append(self.check_async(request))
return self._loop.run_until_complete(asyncio.gather(*tasks))
Expand All @@ -158,18 +192,17 @@ class DNSBLIpChecker(BaseDNSBLChecker):
"""
Checker for ips
"""
def prepare_query(self, request):

def prepare_query(self, request: Union[bytes, str]) -> str:
address = ipaddress.ip_address(request)
if address.version == 4:
return '.'.join(reversed(request.split('.')))
return ".".join(reversed(request.split(".")))
elif address.version == 6:
# according to RFC: https://tools.ietf.org/html/rfc5782#section-2.4
request_stripped = address.exploded.replace(':', '')
return '.'.join(reversed([x for x in request_stripped]))
request_stripped = address.exploded.replace(":", "")
return ".".join(reversed([x for x in request_stripped]))
else:
raise ValueError('unknown ip version')


raise ValueError("unknown ip version")


class DNSBLDomainChecker(BaseDNSBLChecker):
Expand All @@ -178,33 +211,50 @@ class DNSBLDomainChecker(BaseDNSBLChecker):
"""

# https://regex101.com/r/vdrgm7/1
DOMAIN_REGEX = re.compile(r"^(((?!-))(xn--|_{1,1})?[a-z0-9-]{0,61}[a-z0-9]{1,1}\.)*(xn--[a-z0-9][a-z0-9\-]{0,60}|[a-z0-9-]{1,30}\.[a-z]{2,})$")
DOMAIN_REGEX = re.compile(
r"^(((?!-))(xn--|_{1,1})?[a-z0-9-]{0,61}[a-z0-9]{1,1}\.)*(xn--[a-z0-9][a-z0-9\-]{0,60}|[a-z0-9-]{1,30}\.[a-z]{2,})$"
)

def __init__(self, providers=BASE_DOMAIN_PROVIDERS, timeout=5,
tries=2, concurrency=200, loop=None):
super().__init__(providers=providers, timeout=timeout,
tries=tries, concurrency=concurrency, loop=loop)
def __init__(
self,
providers: list[Provider] = BASE_DOMAIN_PROVIDERS,
timeout: float = 5,
tries: int = 2,
concurrency: int = 200,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
super().__init__(
providers=providers,
timeout=timeout,
tries=tries,
concurrency=concurrency,
loop=loop,
)

def prepare_query(self, request):
request = request.lower() # Adding support for capitalized letters in domain name.
def prepare_query(self, request: Union[str, bytes]) -> str:
request = (
request.lower()
) # Adding support for capitalized letters in domain name.
domain_idna = idna.encode(request).decode()
if not self.DOMAIN_REGEX.match(domain_idna):
raise ValueError('should be valid domain, got %s' % domain_idna)
raise ValueError("should be valid domain, got %s" % domain_idna)
return domain_idna


# COMPAT
class DNSBLChecker(DNSBLIpChecker):
"""
Will be deprecated, use DNSBLIpChecker
"""

def __init__(self, *args, **kwargs):
warnings.warn('deprecated, use DNSBLIpChecker', DeprecationWarning)
warnings.warn("deprecated, use DNSBLIpChecker", DeprecationWarning)
super().__init__(*args, **kwargs)

def check_ip(self, addr):
warnings.warn('deprecated, use check method instead', DeprecationWarning)
warnings.warn("deprecated, use check method instead", DeprecationWarning)
return self.check(addr)

def check_ips(self, addrs):
warnings.warn('deprecated, use bulk check method instead', DeprecationWarning)
warnings.warn("deprecated, use bulk check method instead", DeprecationWarning)
return self.bulk_check(addrs)
Loading
Loading