Skip to content

Commit 4ad0f43

Browse files
authored
Merge pull request aws#8336 from kdaily/ssm-session-manager-pluging-env-variable
Pass StartSession response as env variable
2 parents a453709 + 0d5e0c1 commit 4ad0f43

File tree

4 files changed

+441
-24
lines changed

4 files changed

+441
-24
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"type": "enhancement",
3+
"category": "``ssm`` Session Manager",
4+
"description": "Pass StartSession API response as environment variable to session-manager-plugin"
5+
}

awscli/customizations/sessionmanager.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import logging
1414
import json
1515
import errno
16+
import os
17+
import re
1618

17-
from subprocess import check_call
19+
from subprocess import check_call, check_output
1820
from awscli.compat import ignore_user_entered_signals
1921
from awscli.clidriver import ServiceOperation, CLIOperationCaller
2022

@@ -44,8 +46,43 @@ def add_custom_start_session(session, command_table, **kwargs):
4446
)
4547

4648

47-
class StartSessionCommand(ServiceOperation):
49+
class VersionRequirement:
50+
WHITESPACE_REGEX = re.compile(r"\s+")
51+
SSM_SESSION_PLUGIN_VERSION_REGEX = re.compile(r"^\d+(\.\d+){0,3}$")
52+
53+
def __init__(self, min_version):
54+
self.min_version = min_version
55+
56+
def meets_requirement(self, version):
57+
ssm_plugin_version = self._sanitize_plugin_version(version)
58+
if self._is_valid_version(ssm_plugin_version):
59+
norm_version, norm_min_version = self._normalize(
60+
ssm_plugin_version, self.min_version
61+
)
62+
return norm_version > norm_min_version
63+
else:
64+
return False
65+
66+
def _sanitize_plugin_version(self, plugin_version):
67+
return re.sub(self.WHITESPACE_REGEX, "", plugin_version)
68+
69+
def _is_valid_version(self, plugin_version):
70+
return bool(
71+
self.SSM_SESSION_PLUGIN_VERSION_REGEX.match(plugin_version)
72+
)
73+
74+
def _normalize(self, v1, v2):
75+
v1_parts = [int(v) for v in v1.split(".")]
76+
v2_parts = [int(v) for v in v2.split(".")]
77+
while len(v1_parts) != len(v2_parts):
78+
if len(v1_parts) - len(v2_parts) > 0:
79+
v2_parts.append(0)
80+
else:
81+
v1_parts.append(0)
82+
return v1_parts, v2_parts
4883

84+
85+
class StartSessionCommand(ServiceOperation):
4986
def create_help_command(self):
5087
help_command = super(
5188
StartSessionCommand, self).create_help_command()
@@ -55,6 +92,9 @@ def create_help_command(self):
5592

5693

5794
class StartSessionCaller(CLIOperationCaller):
95+
LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR = "1.2.497.0"
96+
DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"
97+
5898
def invoke(self, service_name, operation_name, parameters,
5999
parsed_globals):
60100
client = self._session.create_client(
@@ -70,8 +110,34 @@ def invoke(self, service_name, operation_name, parameters,
70110
profile_name = self._session.profile \
71111
if self._session.profile is not None else ''
72112
endpoint_url = client.meta.endpoint_url
113+
ssm_env_name = self.DEFAULT_SSM_ENV_NAME
73114

74115
try:
116+
session_parameters = {
117+
"SessionId": response["SessionId"],
118+
"TokenValue": response["TokenValue"],
119+
"StreamUrl": response["StreamUrl"],
120+
}
121+
start_session_response = json.dumps(session_parameters)
122+
123+
plugin_version = check_output(
124+
["session-manager-plugin", "--version"], text=True
125+
)
126+
env = os.environ.copy()
127+
128+
# Check if this plugin supports passing the start session response
129+
# as an environment variable name. If it does, it will set the
130+
# value to the response from the start_session operation to the env
131+
# variable defined in DEFAULT_SSM_ENV_NAME. If the session plugin
132+
# version is invalid or older than the version defined in
133+
# LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR, it will fall back to
134+
# passing the start_session response directly.
135+
version_requirement = VersionRequirement(
136+
min_version=self.LAST_PLUGIN_VERSION_WITHOUT_ENV_VAR
137+
)
138+
if version_requirement.meets_requirement(plugin_version):
139+
env[ssm_env_name] = start_session_response
140+
start_session_response = ssm_env_name
75141
# ignore_user_entered_signals ignores these signals
76142
# because if signals which kills the process are not
77143
# captured would kill the foreground process but not the
@@ -81,12 +147,13 @@ def invoke(self, service_name, operation_name, parameters,
81147
with ignore_user_entered_signals():
82148
# call executable with necessary input
83149
check_call(["session-manager-plugin",
84-
json.dumps(response),
150+
start_session_response,
85151
region_name,
86152
"StartSession",
87153
profile_name,
88154
json.dumps(parameters),
89-
endpoint_url])
155+
endpoint_url], env=env)
156+
90157
return 0
91158
except OSError as ex:
92159
if ex.errno == errno.ENOENT:

tests/functional/ssm/test_start_session.py

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,119 @@
1515

1616
from awscli.testutils import BaseAWSCommandParamsTest
1717
from awscli.testutils import BaseAWSHelpOutputTest
18-
from awscli.testutils import mock
18+
from awscli.testutils import mock
1919

20-
class TestSessionManager(BaseAWSCommandParamsTest):
2120

21+
class TestSessionManager(BaseAWSCommandParamsTest):
2222
@mock.patch('awscli.customizations.sessionmanager.check_call')
23-
def test_start_session_success(self, mock_check_call):
23+
@mock.patch("awscli.customizations.sessionmanager.check_output")
24+
def test_start_session_success(self, mock_check_output, mock_check_call):
2425
cmdline = 'ssm start-session --target instance-id'
2526
mock_check_call.return_value = 0
26-
self.parsed_responses = [{
27+
mock_check_output.return_value = "1.2.0.0\n"
28+
expected_response = {
2729
"SessionId": "session-id",
2830
"TokenValue": "token-value",
29-
"StreamUrl": "stream-url"
30-
}]
31+
"StreamUrl": "stream-url",
32+
}
33+
self.parsed_responses = [expected_response]
34+
start_session_params = {"Target": "instance-id"}
35+
36+
self.run_cmd(cmdline, expected_rc=0)
37+
38+
mock_check_call.assert_called_once_with(
39+
[
40+
"session-manager-plugin",
41+
json.dumps(expected_response),
42+
mock.ANY,
43+
"StartSession",
44+
mock.ANY,
45+
json.dumps(start_session_params),
46+
mock.ANY,
47+
],
48+
env=self.environ,
49+
)
50+
51+
@mock.patch("awscli.customizations.sessionmanager.check_call")
52+
@mock.patch("awscli.customizations.sessionmanager.check_output")
53+
def test_start_session_with_new_version_plugin_success(
54+
self, mock_check_output, mock_check_call
55+
):
56+
cmdline = "ssm start-session --target instance-id"
57+
mock_check_call.return_value = 0
58+
mock_check_output.return_value = "1.2.500.0\n"
59+
expected_response = {
60+
"SessionId": "session-id",
61+
"TokenValue": "token-value",
62+
"StreamUrl": "stream-url",
63+
}
64+
self.parsed_responses = [expected_response]
65+
66+
ssm_env_name = "AWS_SSM_START_SESSION_RESPONSE"
67+
start_session_params = {"Target": "instance-id"}
68+
expected_env = self.environ.copy()
69+
expected_env.update({ssm_env_name: json.dumps(expected_response)})
70+
3171
self.run_cmd(cmdline, expected_rc=0)
3272
self.assertEqual(self.operations_called[0][0].name,
3373
'StartSession')
3474
self.assertEqual(self.operations_called[0][1],
3575
{'Target': 'instance-id'})
36-
actual_response = json.loads(mock_check_call.call_args[0][0][1])
37-
self.assertEqual(
38-
{"SessionId": "session-id",
39-
"TokenValue": "token-value",
40-
"StreamUrl": "stream-url"},
41-
actual_response)
76+
77+
mock_check_call.assert_called_once_with(
78+
[
79+
"session-manager-plugin",
80+
ssm_env_name,
81+
mock.ANY,
82+
"StartSession",
83+
mock.ANY,
84+
json.dumps(start_session_params),
85+
mock.ANY,
86+
],
87+
env=expected_env,
88+
)
4289

4390
@mock.patch('awscli.customizations.sessionmanager.check_call')
44-
def test_start_session_fails(self, mock_check_call):
91+
@mock.patch("awscli.customizations.sessionmanager.check_output")
92+
def test_start_session_fails(self, mock_check_output, mock_check_call):
93+
cmdline = "ssm start-session --target instance-id"
94+
mock_check_output.return_value = "1.2.500.0\n"
95+
mock_check_call.side_effect = OSError(errno.ENOENT, "some error")
96+
self.parsed_responses = [
97+
{
98+
"SessionId": "session-id",
99+
"TokenValue": "token-value",
100+
"StreamUrl": "stream-url",
101+
}
102+
]
103+
self.run_cmd(cmdline, expected_rc=255)
104+
self.assertEqual(
105+
self.operations_called[0][0].name, "StartSession"
106+
)
107+
self.assertEqual(
108+
self.operations_called[0][1], {"Target": "instance-id"}
109+
)
110+
self.assertEqual(
111+
self.operations_called[1][0].name, "TerminateSession"
112+
)
113+
self.assertEqual(
114+
self.operations_called[1][1], {"SessionId": "session-id"}
115+
)
116+
117+
@mock.patch("awscli.customizations.sessionmanager.check_call")
118+
@mock.patch("awscli.customizations.sessionmanager.check_output")
119+
def test_start_session_when_get_plugin_version_fails(
120+
self, mock_check_output, mock_check_call
121+
):
45122
cmdline = 'ssm start-session --target instance-id'
46-
mock_check_call.side_effect = OSError(errno.ENOENT, 'some error')
47-
self.parsed_responses = [{
48-
"SessionId": "session-id"
49-
}]
123+
mock_check_output.side_effect = OSError(errno.ENOENT, 'some error')
124+
self.parsed_responses = [
125+
{
126+
"SessionId": "session-id",
127+
"TokenValue": "token-value",
128+
"StreamUrl": "stream-url",
129+
}
130+
]
50131
self.run_cmd(cmdline, expected_rc=255)
51132
self.assertEqual(self.operations_called[0][0].name,
52133
'StartSession')

0 commit comments

Comments
 (0)