Skip to content

Commit af6e103

Browse files
committed
MSK IAM Authentication implementation
1 parent 5e508ed commit af6e103

File tree

2 files changed

+246
-1
lines changed

2 files changed

+246
-1
lines changed

kafka/conn.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import kafka.errors as Errors
2626
from kafka.future import Future
2727
from kafka.metrics.stats import Avg, Count, Max, Rate
28+
from kafka.msk import AwsMskIamClient
2829
from kafka.oauth.abstract import AbstractTokenProvider
2930
from kafka.protocol.admin import SaslHandShakeRequest
3031
from kafka.protocol.commit import OffsetFetchRequest
@@ -81,6 +82,12 @@ class SSLWantWriteError(Exception):
8182
gssapi = None
8283
GSSError = None
8384

85+
# needed for AWS_MSK_IAM authentication:
86+
try:
87+
from botocore.session import Session as BotoSession
88+
except ImportError:
89+
# no botocore available, will disable AWS_MSK_IAM mechanism
90+
BotoSession = None
8491

8592
AFI_NAMES = {
8693
socket.AF_UNSPEC: "unspecified",
@@ -224,7 +231,7 @@ class BrokerConnection(object):
224231
'sasl_oauth_token_provider': None
225232
}
226233
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
227-
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
234+
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', 'AWS_MSK_IAM')
228235

229236
def __init__(self, host, port, afi, **configs):
230237
self.host = host
@@ -269,6 +276,11 @@ def __init__(self, host, port, afi, **configs):
269276
token_provider = self.config['sasl_oauth_token_provider']
270277
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
271278
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
279+
280+
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
281+
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
282+
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
283+
272284
# This is not a general lock / this class is not generally thread-safe yet
273285
# However, to avoid pushing responsibility for maintaining
274286
# per-connection locks to the upstream client, we will use this lock to
@@ -552,6 +564,8 @@ def _handle_sasl_handshake_response(self, future, response):
552564
return self._try_authenticate_gssapi(future)
553565
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
554566
return self._try_authenticate_oauth(future)
567+
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
568+
return self._try_authenticate_aws_msk_iam(future)
555569
else:
556570
return future.failure(
557571
Errors.UnsupportedSaslMechanismError(
@@ -652,6 +666,40 @@ def _try_authenticate_plain(self, future):
652666
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
653667
return future.success(True)
654668

669+
def _try_authenticate_aws_msk_iam(self, future):
670+
session = BotoSession()
671+
client = AwsMskIamClient(
672+
host=self.host,
673+
boto_session=session,
674+
)
675+
676+
msg = client.first_message()
677+
size = Int32.encode(len(msg))
678+
679+
err = None
680+
close = False
681+
with self._lock:
682+
if not self._can_send_recv():
683+
err = Errors.NodeNotReadyError(str(self))
684+
close = False
685+
else:
686+
try:
687+
self._send_bytes_blocking(size + msg)
688+
data = self._recv_bytes_blocking(4)
689+
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
690+
except (ConnectionError, TimeoutError) as e:
691+
log.exception("%s: Error receiving reply from server", self)
692+
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
693+
close = True
694+
695+
if err is not None:
696+
if close:
697+
self.close(error=err)
698+
return future.failure(err)
699+
700+
log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
701+
return future.success(True)
702+
655703
def _try_authenticate_gssapi(self, future):
656704
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
657705
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name

kafka/msk.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import datetime
2+
import hashlib
3+
import hmac
4+
import json
5+
import string
6+
7+
from kafka.vendor.six.moves import urllib
8+
9+
10+
class AwsMskIamClient:
11+
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'
12+
13+
def __init__(self, host, boto_session):
14+
"""
15+
Arguments:
16+
host (str): The hostname of the broker.
17+
boto_session (botocore.BotoSession) the boto session
18+
"""
19+
self.algorithm = 'AWS4-HMAC-SHA256'
20+
self.expires = '900'
21+
self.hashfunc = hashlib.sha256
22+
self.headers = [
23+
('host', host)
24+
]
25+
self.version = '2020_10_22'
26+
27+
self.service = 'kafka-cluster'
28+
self.action = '{}:Connect'.format(self.service)
29+
30+
now = datetime.datetime.utcnow()
31+
self.datestamp = now.strftime('%Y%m%d')
32+
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')
33+
34+
self.host = host
35+
self.boto_session = boto_session
36+
37+
@property
38+
def access_key(self):
39+
return self.boto_session.get_credentials().access_key
40+
41+
@property
42+
def secret_key(self):
43+
return self.boto_session.get_credentials().secret_key
44+
45+
@property
46+
def token(self):
47+
return self.boto_session.get_credentials().token
48+
49+
@property
50+
def region(self):
51+
# TODO: This logic is not perfect and should be revisited
52+
for host in self.host.split(','):
53+
if 'amazonaws.com' in host:
54+
return host.split('.')[-3]
55+
return 'us-west-2'
56+
57+
@property
58+
def _credential(self):
59+
return '{0.access_key}/{0._scope}'.format(self)
60+
61+
@property
62+
def _scope(self):
63+
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)
64+
65+
@property
66+
def _signed_headers(self):
67+
"""
68+
Returns (str):
69+
An alphabetically sorted, semicolon-delimited list of lowercase
70+
request header names.
71+
"""
72+
return ';'.join(sorted(k.lower() for k, _ in self.headers))
73+
74+
@property
75+
def _canonical_headers(self):
76+
"""
77+
Returns (str):
78+
A newline-delited list of header names and values.
79+
Header names are lowercased.
80+
"""
81+
return '\n'.join(map(':'.join, self.headers)) + '\n'
82+
83+
@property
84+
def _canonical_request(self):
85+
"""
86+
Returns (str):
87+
An AWS Signature Version 4 canonical request in the format:
88+
<Method>\n
89+
<Path>\n
90+
<CanonicalQueryString>\n
91+
<CanonicalHeaders>\n
92+
<SignedHeaders>\n
93+
<HashedPayload>
94+
"""
95+
# The hashed_payload is always an empty string for MSK.
96+
hashed_payload = self.hashfunc(b'').hexdigest()
97+
return '\n'.join((
98+
'GET',
99+
'/',
100+
self._canonical_querystring,
101+
self._canonical_headers,
102+
self._signed_headers,
103+
hashed_payload,
104+
))
105+
106+
@property
107+
def _canonical_querystring(self):
108+
"""
109+
Returns (str):
110+
A '&'-separated list of URI-encoded key/value pairs.
111+
"""
112+
params = []
113+
params.append(('Action', self.action))
114+
params.append(('X-Amz-Algorithm', self.algorithm))
115+
params.append(('X-Amz-Credential', self._credential))
116+
params.append(('X-Amz-Date', self.timestamp))
117+
params.append(('X-Amz-Expires', self.expires))
118+
if self.token:
119+
params.append(('X-Amz-Security-Token', self.token))
120+
params.append(('X-Amz-SignedHeaders', self._signed_headers))
121+
122+
return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)
123+
124+
@property
125+
def _signing_key(self):
126+
"""
127+
Returns (bytes):
128+
An AWS Signature V4 signing key generated from the secret_key, date,
129+
region, service, and request type.
130+
"""
131+
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
132+
key = self._hmac(key, self.region)
133+
key = self._hmac(key, self.service)
134+
key = self._hmac(key, 'aws4_request')
135+
return key
136+
137+
@property
138+
def _signing_str(self):
139+
"""
140+
Returns (str):
141+
A string used to sign the AWS Signature V4 payload in the format:
142+
<Algorithm>\n
143+
<Timestamp>\n
144+
<Scope>\n
145+
<CanonicalRequestHash>
146+
"""
147+
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
148+
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))
149+
150+
def _uriencode(self, msg):
151+
"""
152+
Arguments:
153+
msg (str): A string to URI-encode.
154+
155+
Returns (str):
156+
The URI-encoded version of the provided msg, following the encoding
157+
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
158+
"""
159+
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)
160+
161+
def _hmac(self, key, msg):
162+
"""
163+
Arguments:
164+
key (bytes): A key to use for the HMAC digest.
165+
msg (str): A value to include in the HMAC digest.
166+
Returns (bytes):
167+
An HMAC digest of the given key and msg.
168+
"""
169+
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()
170+
171+
def first_message(self):
172+
"""
173+
Returns (bytes):
174+
An encoded JSON authentication payload that can be sent to the
175+
broker.
176+
"""
177+
signature = hmac.new(
178+
self._signing_key,
179+
self._signing_str.encode('utf-8'),
180+
digestmod=self.hashfunc,
181+
).hexdigest()
182+
msg = {
183+
'version': self.version,
184+
'host': self.host,
185+
'user-agent': 'kafka-python',
186+
'action': self.action,
187+
'x-amz-algorithm': self.algorithm,
188+
'x-amz-credential': self._credential,
189+
'x-amz-date': self.timestamp,
190+
'x-amz-signedheaders': self._signed_headers,
191+
'x-amz-expires': self.expires,
192+
'x-amz-signature': signature,
193+
}
194+
if self.token:
195+
msg['x-amz-security-token'] = self.token
196+
197+
return json.dumps(msg, separators=(',', ':')).encode('utf-8')

0 commit comments

Comments
 (0)