Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
## [Unreleased]

## Added

- Support using alternative documents and parameters with
`SsmPortForwardingSession`
([#22](https://github.com/ackama/aws_ec2_environment/pull/22))

## [0.1.0] - 2022-08-17

- Initial release
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,44 @@ task :forward_port, %i[instance_id remote_port local_port] => :environment do |_
end
```

You can also use specific documents, and pass in extra parameters, which can be
useful for using tunnels to access other private resources like database
instances:

```ruby
require "aws_ec2_environment"

desc "Dumps a copy of the postgres database using AWS and PG environment variables"
task :dump_pg_database, %i[instance_id dump_file] => :environment do |_, args|
instance_id = args.fetch(:instance_id)
dump_file = args.fetch(:dump_file)

remote_host = ENV.fetch("PGHOST")
remote_port = ENV.fetch("PGPORT", 5432)

session = AwsEc2Environment::SsmPortForwardingSession.new(
instance_id,
remote_port,
document: "AWS-StartPortForwardingSessionToRemoteHost",
extra_params: { "host" => [remote_host] }
)

at_exit { session.close }

local_port = session.wait_for_local_port

system([
"pg_dump",
"--format=c",
"--no-owner",
"--no-privileges",
"--host=localhost",
"--port=#{local_port}",
"--file=#{dump_file}",
].join(" "))
end
```

### AWS Authentication and Permissions

Since this gem interacts with AWS, it must be configured with credentials - see
Expand Down
10 changes: 5 additions & 5 deletions lib/aws_ec2_environment/ssm_port_forwarding_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ class SessionProcessError < Error; end
# rubocop:disable Metrics/ParameterLists
def initialize(
instance_id, remote_port,
document: "AWS-StartPortForwardingSession",
local_port: nil, logger: Logger.new($stdout),
timeout: 15, reason: nil
timeout: 15, reason: nil, extra_params: {}
)
# rubocop:enable Metrics/ParameterLists
@logger = logger
Expand All @@ -32,7 +33,7 @@ def initialize(
@local_port = nil
@timeout = timeout

@reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason))
@reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason, document, extra_params))

@cmd_output = ""
@session_id = wait_for_session_id
Expand Down Expand Up @@ -64,9 +65,8 @@ def wait_for_local_port

private

def ssm_port_forward_cmd(local_port, reason)
document_name = "AWS-StartPortForwardingSession"
parameters = { "portNumber" => [remote_port.to_s] }
def ssm_port_forward_cmd(local_port, reason, document_name, extra_parameters)
parameters = extra_parameters.merge({ "portNumber" => [remote_port.to_s] })
parameters["localPortNumber"] = [local_port.to_s] unless local_port.nil?
flags = [
["--target", instance_id],
Expand Down
6 changes: 4 additions & 2 deletions sig/aws_ec2_environment/ssm_port_forwarding_session.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ class AwsEc2Environment
def initialize: (
String instance_id,
Integer remote_port,
?document: String,
?local_port: Integer | nil,
?logger: Logger,
?timeout: Numeric,
?reason: String | nil
?reason: String | nil,
?extra_params: Hash[String, untyped]
) -> void

def close: () -> void
Expand All @@ -38,7 +40,7 @@ class AwsEc2Environment
@writer: IO
@cmd_output: String

def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason) -> String
def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason, String document_name, Hash[String, untyped] extra_params) -> String

# Checks the cmd process output until either the given +pattern+ matches or the +timeout+ is over.
#
Expand Down
92 changes: 92 additions & 0 deletions spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,98 @@ def close; end
)
end
end

context "when a specific document is provided" do
subject(:session) do
described_class.new(
"i-0d9c4bg3f26157a8e",
22,
document: "AWS-StartPortForwardingSessionToRemoteHost",
logger: Logger.new(StringIO.new(log)),
# we can use a really low timeout to make the tests a lot faster,
# since we're not actually going to be writing asynchronously
timeout: 0.00001
)
end

it "uses the document" do
expect { session }.not_to raise_error

parameters = { "portNumber" => ["22"] }
parameters_escaped = Shellwords.escape(parameters.to_json)

expect(PTY).to have_received(:spawn).with(
%w[
aws ssm start-session
--target i-0d9c4bg3f26157a8e
--document-name AWS-StartPortForwardingSessionToRemoteHost
--parameters
].join(" ") + " #{parameters_escaped}"
)
end
end

context "when extra parameters are provided" do
it "merges them" do
expect do
described_class.new(
"i-0d9c4bg3f26157a8e",
22,
extra_params: { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"] },
logger: Logger.new(StringIO.new(log)),
# we can use a really low timeout to make the tests a lot faster,
# since we're not actually going to be writing asynchronously
timeout: 0.00001
)
end.not_to raise_error

parameters = { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"], "portNumber" => ["22"] }
parameters_escaped = Shellwords.escape(parameters.to_json)

expect(PTY).to have_received(:spawn).with(
%w[
aws ssm start-session
--target i-0d9c4bg3f26157a8e
--document-name AWS-StartPortForwardingSession
--parameters
].join(" ") + " #{parameters_escaped}"
)
end

it "overrides them with specific parameters" do
expect do
described_class.new(
"i-0d9c4bg3f26157a8e",
22,
extra_params: {
"host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"],
"localPortNumber" => [1234]
},
local_port: 5432,
logger: Logger.new(StringIO.new(log)),
# we can use a really low timeout to make the tests a lot faster,
# since we're not actually going to be writing asynchronously
timeout: 0.00001
)
end.not_to raise_error

parameters = {
"host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"],
"localPortNumber" => ["5432"],
"portNumber" => ["22"]
}
parameters_escaped = Shellwords.escape(parameters.to_json)

expect(PTY).to have_received(:spawn).with(
%w[
aws ssm start-session
--target i-0d9c4bg3f26157a8e
--document-name AWS-StartPortForwardingSession
--parameters
].join(" ") + " #{parameters_escaped}"
)
end
end
end

describe "#instance_id" do
Expand Down