|
| 1 | +# This is a direct lift from |
| 2 | +# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/_resolver.py. |
| 3 | +# We copy it here as we need to instantiate `GAIResolver` manually, but it is a |
| 4 | +# private class. |
| 5 | + |
| 6 | + |
| 7 | +from socket import ( |
| 8 | + AF_INET, |
| 9 | + AF_INET6, |
| 10 | + AF_UNSPEC, |
| 11 | + SOCK_DGRAM, |
| 12 | + SOCK_STREAM, |
| 13 | + gaierror, |
| 14 | + getaddrinfo, |
| 15 | +) |
| 16 | + |
| 17 | +from zope.interface import implementer |
| 18 | + |
| 19 | +from twisted.internet.address import IPv4Address, IPv6Address |
| 20 | +from twisted.internet.interfaces import IHostnameResolver, IHostResolution |
| 21 | +from twisted.internet.threads import deferToThreadPool |
| 22 | + |
| 23 | + |
| 24 | +@implementer(IHostResolution) |
| 25 | +class HostResolution: |
| 26 | + """ |
| 27 | + The in-progress resolution of a given hostname. |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__(self, name): |
| 31 | + """ |
| 32 | + Create a L{HostResolution} with the given name. |
| 33 | + """ |
| 34 | + self.name = name |
| 35 | + |
| 36 | + def cancel(self): |
| 37 | + # IHostResolution.cancel |
| 38 | + raise NotImplementedError() |
| 39 | + |
| 40 | + |
| 41 | +_any = frozenset([IPv4Address, IPv6Address]) |
| 42 | + |
| 43 | +_typesToAF = { |
| 44 | + frozenset([IPv4Address]): AF_INET, |
| 45 | + frozenset([IPv6Address]): AF_INET6, |
| 46 | + _any: AF_UNSPEC, |
| 47 | +} |
| 48 | + |
| 49 | +_afToType = { |
| 50 | + AF_INET: IPv4Address, |
| 51 | + AF_INET6: IPv6Address, |
| 52 | +} |
| 53 | + |
| 54 | +_transportToSocket = { |
| 55 | + "TCP": SOCK_STREAM, |
| 56 | + "UDP": SOCK_DGRAM, |
| 57 | +} |
| 58 | + |
| 59 | +_socktypeToType = { |
| 60 | + SOCK_STREAM: "TCP", |
| 61 | + SOCK_DGRAM: "UDP", |
| 62 | +} |
| 63 | + |
| 64 | + |
| 65 | +@implementer(IHostnameResolver) |
| 66 | +class GAIResolver: |
| 67 | + """ |
| 68 | + L{IHostnameResolver} implementation that resolves hostnames by calling |
| 69 | + L{getaddrinfo} in a thread. |
| 70 | + """ |
| 71 | + |
| 72 | + def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): |
| 73 | + """ |
| 74 | + Create a L{GAIResolver}. |
| 75 | + @param reactor: the reactor to schedule result-delivery on |
| 76 | + @type reactor: L{IReactorThreads} |
| 77 | + @param getThreadPool: a function to retrieve the thread pool to use for |
| 78 | + scheduling name resolutions. If not supplied, the use the given |
| 79 | + C{reactor}'s thread pool. |
| 80 | + @type getThreadPool: 0-argument callable returning a |
| 81 | + L{twisted.python.threadpool.ThreadPool} |
| 82 | + @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly |
| 83 | + parameterized for testing. |
| 84 | + @type getaddrinfo: callable with the same signature as L{getaddrinfo} |
| 85 | + """ |
| 86 | + self._reactor = reactor |
| 87 | + self._getThreadPool = ( |
| 88 | + reactor.getThreadPool if getThreadPool is None else getThreadPool |
| 89 | + ) |
| 90 | + self._getaddrinfo = getaddrinfo |
| 91 | + |
| 92 | + def resolveHostName( |
| 93 | + self, |
| 94 | + resolutionReceiver, |
| 95 | + hostName, |
| 96 | + portNumber=0, |
| 97 | + addressTypes=None, |
| 98 | + transportSemantics="TCP", |
| 99 | + ): |
| 100 | + """ |
| 101 | + See L{IHostnameResolver.resolveHostName} |
| 102 | + @param resolutionReceiver: see interface |
| 103 | + @param hostName: see interface |
| 104 | + @param portNumber: see interface |
| 105 | + @param addressTypes: see interface |
| 106 | + @param transportSemantics: see interface |
| 107 | + @return: see interface |
| 108 | + """ |
| 109 | + pool = self._getThreadPool() |
| 110 | + addressFamily = _typesToAF[ |
| 111 | + _any if addressTypes is None else frozenset(addressTypes) |
| 112 | + ] |
| 113 | + socketType = _transportToSocket[transportSemantics] |
| 114 | + |
| 115 | + def get(): |
| 116 | + try: |
| 117 | + return self._getaddrinfo( |
| 118 | + hostName, portNumber, addressFamily, socketType |
| 119 | + ) |
| 120 | + except gaierror: |
| 121 | + return [] |
| 122 | + |
| 123 | + d = deferToThreadPool(self._reactor, pool, get) |
| 124 | + resolution = HostResolution(hostName) |
| 125 | + resolutionReceiver.resolutionBegan(resolution) |
| 126 | + |
| 127 | + @d.addCallback |
| 128 | + def deliverResults(result): |
| 129 | + for family, socktype, _proto, _cannoname, sockaddr in result: |
| 130 | + addrType = _afToType[family] |
| 131 | + resolutionReceiver.addressResolved( |
| 132 | + addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr) |
| 133 | + ) |
| 134 | + resolutionReceiver.resolutionComplete() |
| 135 | + |
| 136 | + return resolution |
0 commit comments