Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions internal/aws/ssm_remote_access_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type SSMRemoteAccessClientInterface interface {

// PodExecInterface defines the interface for executing commands in pods.
type PodExecInterface interface {
ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string) (string, error)
ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string, stdin string) (string, error)
}

// Constants for SSM remote access strategy
Expand Down Expand Up @@ -133,12 +133,13 @@ func (s *SSMRemoteAccessStrategy) CleanupSSMManagedNodes(ctx context.Context, po
// isSSMRegistrationCompleted checks if SSM registration is already done for this pod
func (s *SSMRemoteAccessStrategy) isSSMRegistrationCompleted(ctx context.Context, pod *corev1.Pod) bool {
logger := logf.FromContext(ctx).WithValues("pod", pod.Name)
noStdin := "" // For commands that don't need stdin input

// TODO: improve race condition handling for rapid pod events

// Check for completion marker file in sidecar
cmd := []string{"test", "-f", SSMRegistrationMarkerFile}
_, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, cmd)
_, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, cmd, noStdin)

completed := err == nil
logger.V(2).Info("SSM registration completion check", "completed", completed)
Expand All @@ -148,6 +149,7 @@ func (s *SSMRemoteAccessStrategy) isSSMRegistrationCompleted(ctx context.Context
// performSSMRegistration handles the SSM activation and registration process
func (s *SSMRemoteAccessStrategy) performSSMRegistration(ctx context.Context, pod *corev1.Pod, workspace *workspacev1alpha1.Workspace, accessStrategy *workspacev1alpha1.WorkspaceAccessStrategy) error {
logger := logf.FromContext(ctx).WithValues("pod", pod.Name, "workspace", workspace.Name)
noStdin := "" // For commands that don't need stdin input

if s.ssmClient == nil {
return fmt.Errorf("SSM client not available")
Expand All @@ -160,25 +162,29 @@ func (s *SSMRemoteAccessStrategy) performSSMRegistration(ctx context.Context, po
return fmt.Errorf("failed to create SSM activation: %w", err)
}

// Step 2: Run register-ssm.sh with environment variables
// Step 2: Run register-ssm.sh with sensitive values passed via stdin
logger.Info("Running SSM registration script in sidecar")
region := s.ssmClient.GetRegion()
cmd := []string{"bash", "-c", fmt.Sprintf("env ACTIVATION_ID=%s ACTIVATION_CODE=%s REGION=%s %s", activationId, activationCode, region, SSMRegistrationScript)}
if _, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, cmd); err != nil {

// Use stdin to pass only sensitive values securely
cmd := []string{"bash", "-c", fmt.Sprintf("read ACTIVATION_ID && read ACTIVATION_CODE && env ACTIVATION_ID=\"$ACTIVATION_ID\" ACTIVATION_CODE=\"$ACTIVATION_CODE\" REGION=%s %s", region, SSMRegistrationScript)}
stdinData := fmt.Sprintf("%s\n%s\n", activationId, activationCode)

if _, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, cmd, stdinData); err != nil {
return fmt.Errorf("failed to execute SSM registration script: %w", err)
}

// Step 3: Start remote access server in main container
logger.Info("Starting remote access server in main container")
serverCmd := []string{"bash", "-c", fmt.Sprintf("sudo %s > /dev/null 2>&1 &", RemoteAccessServerPath)}
if _, err := s.podExecUtil.ExecInPod(ctx, pod, WorkspaceContainerName, serverCmd); err != nil {
if _, err := s.podExecUtil.ExecInPod(ctx, pod, WorkspaceContainerName, serverCmd, noStdin); err != nil {
return fmt.Errorf("failed to start remote access server: %w", err)
}

// Step 4: Create completion marker file
logger.Info("Creating SSM registration completion marker")
markerCmd := []string{"touch", SSMRegistrationMarkerFile}
if _, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, markerCmd); err != nil {
if _, err := s.podExecUtil.ExecInPod(ctx, pod, SSMAgentSidecarContainerName, markerCmd, noStdin); err != nil {
return fmt.Errorf("failed to create completion marker: %w", err)
}

Expand Down
26 changes: 17 additions & 9 deletions internal/aws/ssm_remote_access_strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ type MockPodExecUtil struct {
mock.Mock
}

func (m *MockPodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string) (string, error) {
args := m.Called(ctx, pod, containerName, cmd)
func (m *MockPodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string, stdin string) (string, error) {
args := m.Called(ctx, pod, containerName, cmd, stdin)
return args.String(0), args.Error(1)
}

Expand Down Expand Up @@ -227,7 +227,7 @@ func TestInitSSMAgent_AlreadyCompleted(t *testing.T) {
// Create mock PodExecUtil that simulates completed registration
mockPodExecUtil := &MockPodExecUtil{}
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
[]string{"test", "-f", SSMRegistrationMarkerFile}).Return("", nil) // File exists
[]string{"test", "-f", SSMRegistrationMarkerFile}, "").Return("", nil) // File exists

mockSSMClient := &MockSSMRemoteAccessClient{}
strategy, err := NewSSMRemoteAccessStrategy(mockSSMClient, mockPodExecUtil)
Expand Down Expand Up @@ -260,25 +260,31 @@ func TestInitSSMAgent_SuccessFlow(t *testing.T) {

// First call: check if registration completed (return error = not completed)
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
[]string{"test", "-f", SSMRegistrationMarkerFile}).Return("", errors.New("file not found"))
[]string{"test", "-f", SSMRegistrationMarkerFile}, "").Return("", errors.New("file not found"))

// Second call: registration script execution
// Second call: registration script execution with stdin
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
mock.MatchedBy(func(cmd []string) bool {
return len(cmd) == 3 && cmd[0] == bashCommand && cmd[1] == "-c" &&
strings.Contains(cmd[2], "read ACTIVATION_ID && read ACTIVATION_CODE") &&
strings.Contains(cmd[2], "REGION=us-west-2") &&
strings.Contains(cmd[2], "register-ssm.sh")
}), mock.MatchedBy(func(stdin string) bool {
return strings.Contains(stdin, "test-activation-id") &&
strings.Contains(stdin, "test-activation-code") &&
!strings.Contains(stdin, "us-west-2") // Region should not be in stdin
})).Return("", nil)

// Third call: remote access server start
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, WorkspaceContainerName,
mock.MatchedBy(func(cmd []string) bool {
return len(cmd) == 3 && cmd[0] == bashCommand && cmd[1] == "-c" &&
strings.Contains(cmd[2], "remote-access-server")
})).Return("", nil)
}), "").Return("", nil)

// Fourth call: completion marker creation
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
[]string{"touch", SSMRegistrationMarkerFile}).Return("", nil)
[]string{"touch", SSMRegistrationMarkerFile}, "").Return("", nil)

// Create mock SSM client
mockSSMClient := &MockSSMRemoteAccessClient{}
Expand Down Expand Up @@ -322,15 +328,17 @@ func TestInitSSMAgent_RegistrationFailure(t *testing.T) {

// First call: check if registration completed (return error = not completed)
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
[]string{"test", "-f", SSMRegistrationMarkerFile}).Return("", errors.New("file not found"))
[]string{"test", "-f", SSMRegistrationMarkerFile}, "").Return("", errors.New("file not found"))

// Second call: registration script execution fails
expectedError := errors.New("registration script failed")
mockPodExecUtil.On("ExecInPod", mock.Anything, mock.Anything, SSMAgentSidecarContainerName,
mock.MatchedBy(func(cmd []string) bool {
return len(cmd) == 3 && cmd[0] == bashCommand && cmd[1] == "-c" &&
strings.Contains(cmd[2], "read ACTIVATION_ID && read ACTIVATION_CODE") &&
strings.Contains(cmd[2], "REGION=us-west-2") &&
strings.Contains(cmd[2], "register-ssm.sh")
})).Return("", expectedError)
}), mock.Anything).Return("", expectedError)

// Create mock SSM client
mockSSMClient := &MockSSMRemoteAccessClient{}
Expand Down
22 changes: 16 additions & 6 deletions internal/controller/pod_exec_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func NewPodExecUtil() (*PodExecUtil, error) {
}, nil
}

// ExecInPod executes a command in a specific container of a pod
func (p *PodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string) (string, error) {
// ExecInPod executes a command in a specific container of a pod with optional stdin input
func (p *PodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerName string, cmd []string, stdin string) (string, error) {
logger := logf.FromContext(ctx).WithValues("pod", pod.Name, "container", containerName, "cmd", cmd)

// Create exec request
Expand All @@ -57,9 +57,12 @@ func (p *PodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerN
Namespace(pod.Namespace).
SubResource("exec")

// Enable stdin only if we have stdin data
hasStdin := stdin != ""
req.VersionedParams(&corev1.PodExecOptions{
Container: containerName,
Command: cmd,
Stdin: hasStdin,
Stdout: true,
Stderr: true,
}, scheme.ParameterCodec)
Expand All @@ -71,17 +74,24 @@ func (p *PodExecUtil) ExecInPod(ctx context.Context, pod *corev1.Pod, containerN
}

var stdout, stderr bytes.Buffer
err = exec.StreamWithContext(ctx, remotecommand.StreamOptions{
streamOptions := remotecommand.StreamOptions{
Stdout: &stdout,
Stderr: &stderr,
})
}

// Add stdin only if we have data
if hasStdin {
streamOptions.Stdin = strings.NewReader(stdin)
}

err = exec.StreamWithContext(ctx, streamOptions)

output := strings.TrimSpace(stdout.String())
if err != nil {
logger.V(1).Info("Command execution failed", "error", err, "stderr", stderr.String())
logger.V(1).Info("Command execution failed", "hasStdin", hasStdin, "error", err, "stderr", stderr.String())
return output, err
}

logger.V(1).Info("Command executed successfully", "output", output)
logger.V(1).Info("Command executed successfully", "hasStdin", hasStdin, "output", output)
return output, nil
}
98 changes: 94 additions & 4 deletions internal/controller/pod_exec_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ type mockExecutor struct {
stderr string
}

// Mock executor that captures StreamOptions
type mockExecutorWithCapture struct {
streamErr error
stdout string
stderr string
capturedStdin *bool
capturedStdinData *string
}

func (m *mockExecutor) Stream(options remotecommand.StreamOptions) error {
return m.StreamWithContext(context.Background(), options)
}
Expand All @@ -134,6 +143,34 @@ func (m *mockExecutor) StreamWithContext(ctx context.Context, options remotecomm
return nil
}

func (m *mockExecutorWithCapture) Stream(options remotecommand.StreamOptions) error {
return m.StreamWithContext(context.Background(), options)
}

func (m *mockExecutorWithCapture) StreamWithContext(ctx context.Context, options remotecommand.StreamOptions) error {
// Capture whether stdin was provided
if m.capturedStdin != nil {
*m.capturedStdin = options.Stdin != nil
}

// Capture actual stdin data
if options.Stdin != nil && m.capturedStdinData != nil {
buf := make([]byte, 1024)
n, _ := options.Stdin.Read(buf)
*m.capturedStdinData = string(buf[:n])
}

// Write mock output to provided streams
if options.Stdout != nil && m.stdout != "" {
_, _ = options.Stdout.Write([]byte(m.stdout))
}
if options.Stderr != nil && m.stderr != "" {
_, _ = options.Stderr.Write([]byte(m.stderr))
}

return m.streamErr
}

func TestExecInPod_Success(t *testing.T) {
// This test would require complex REST client mocking
// For now, we'll test the integration with real kubeconfig if available
Expand Down Expand Up @@ -164,8 +201,9 @@ func TestExecInPod_Success(t *testing.T) {
},
}

// Test successful execution
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"echo", "hello"})
// Test successful execution without stdin
noStdin := ""
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"echo", "hello"}, noStdin)

if err != nil {
t.Fatalf("Expected no error, got: %v", err)
Expand Down Expand Up @@ -200,7 +238,8 @@ func TestExecInPod_ExecutorCreationFailure(t *testing.T) {
}

// Test executor creation failure
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"echo", "hello"})
noStdin := ""
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"echo", "hello"}, noStdin)

if err == nil {
t.Fatal("Expected error when executor creation fails")
Expand Down Expand Up @@ -242,7 +281,8 @@ func TestExecInPod_StreamExecutionFailure(t *testing.T) {
}

// Test stream execution failure
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"failing-command"})
noStdin := ""
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"failing-command"}, noStdin)

if err == nil {
t.Fatal("Expected error when stream execution fails")
Expand All @@ -255,3 +295,53 @@ func TestExecInPod_StreamExecutionFailure(t *testing.T) {
t.Errorf("Expected 'partial output', got: '%s'", output)
}
}

func TestExecInPod_WithStdin(t *testing.T) {
util, err := NewPodExecUtil()
if err != nil {
t.Skipf("Skipping integration test - requires valid Kubernetes config: %v", err)
return
}

// Save original
original := newSPDYExecutor
defer func() { newSPDYExecutor = original }()

// Capture stdin data from StreamOptions
var stdinProvided bool
var stdinData string
mockExec := &mockExecutorWithCapture{
stdout: "stdin processed",
capturedStdin: &stdinProvided,
capturedStdinData: &stdinData,
}
newSPDYExecutor = func(config *rest.Config, method string, url *url.URL) (remotecommand.Executor, error) {
return mockExec, nil
}

// Create test pod
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "test-pod",
Namespace: "test-namespace",
},
}

// Test execution with stdin
inputData := "test-input\nsecond-line\n"
output, err := util.ExecInPod(context.Background(), pod, "test-container", []string{"bash", "-c", "read line1 && read line2 && echo processed"}, inputData)

if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if output != "stdin processed" {
t.Errorf("Expected 'stdin processed', got: '%s'", output)
}
if !stdinProvided {
t.Error("Expected StreamOptions.Stdin to be provided when stdin data is given")
}
expectedData := "test-input\nsecond-line\n"
if stdinData != expectedData {
t.Errorf("Expected stdin data '%s', got '%s'", expectedData, stdinData)
}
}