Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions internal/adapters/data/ssh_config_file/config_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func (r *Repository) loadConfig() (*ssh_config.Config, error) {
return nil, fmt.Errorf("failed to decode config: %w", err)
}

// Preserve the implicit host (global directives like Include)
if len(cfg.Hosts) > 0 && cfg.Hosts[0].Implicit {
r.implicitHost = cfg.Hosts[0]
}

return cfg, nil
}

Expand All @@ -62,6 +67,12 @@ func (r *Repository) saveConfig(cfg *ssh_config.Config) error {
}
}()

// Restore the implicit host with global directives (Include, etc.) before saving
if r.implicitHost != nil {
// Prepend the implicit host to preserve global directives
cfg.Hosts = append([]*ssh_config.Host{r.implicitHost}, cfg.Hosts...)
}

if err := r.writeConfigToFile(tempFile, cfg); err != nil {
return fmt.Errorf("failed to write config to temporary file: %w", err)
}
Expand Down
164 changes: 164 additions & 0 deletions internal/adapters/data/ssh_config_file/include_preservation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2025.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ssh_config_file

import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/Adembc/lazyssh/internal/core/domain"
"go.uber.org/zap"
)

// TestIncludeDirectivesPreservation tests that Include directives are preserved
// when adding, updating, or deleting server entries
func TestIncludeDirectivesPreservation(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config")
metadataPath := filepath.Join(tempDir, "metadata.json")

// Initial config with Include directives
initialConfig := `# =====================================================================
# Includes
# =====================================================================
Include terraform.d/*
Include personal.d/*
Include work.d/*
Include ~/.orbstack/ssh/config

# =====================================================================
# Personal Servers
# =====================================================================
Host testserver
HostName test.example.com
User testuser
Port 22
`

// Write initial config
err := os.WriteFile(configPath, []byte(initialConfig), 0o600)
if err != nil {
t.Fatalf("Failed to write initial config: %v", err)
}

// Create repository
logger := zap.NewNop().Sugar()
repo := NewRepository(logger, configPath, metadataPath)

// Test 1: Add a new server
newServer := domain.Server{
Alias: "newserver",
Host: "new.example.com",
User: "newuser",
Port: 2222,
}

err = repo.AddServer(newServer)
if err != nil {
t.Fatalf("Failed to add server: %v", err)
}

// Read the config file and verify Include directives are preserved
content, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}

configStr := string(content)

// Check that all Include directives are present
includeDirectives := []string{
"Include terraform.d/*",
"Include personal.d/*",
"Include work.d/*",
"Include ~/.orbstack/ssh/config",
}

for _, directive := range includeDirectives {
if !strings.Contains(configStr, directive) {
t.Errorf("Include directive missing after AddServer: %s\nConfig content:\n%s", directive, configStr)
}
}

// Verify the new server was added
if !strings.Contains(configStr, "Host newserver") {
t.Errorf("New server not found in config")
}

// Test 2: Update existing server
servers, err := repo.ListServers("")
if err != nil {
t.Fatalf("Failed to get servers: %v", err)
}

var testServer domain.Server
for _, s := range servers {
if s.Alias == "testserver" {
testServer = s
break
}
}

updatedServer := testServer
updatedServer.Port = 2200
err = repo.UpdateServer(testServer, updatedServer)
if err != nil {
t.Fatalf("Failed to update server: %v", err)
}

// Read config again
content, err = os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}

configStr = string(content)

// Check that Include directives are still present
for _, directive := range includeDirectives {
if !strings.Contains(configStr, directive) {
t.Errorf("Include directive missing after UpdateServer: %s\nConfig content:\n%s", directive, configStr)
}
}

// Test 3: Delete a server
err = repo.DeleteServer(newServer)
if err != nil {
t.Fatalf("Failed to delete server: %v", err)
}

// Read config again
content, err = os.ReadFile(configPath)
if err != nil {
t.Fatalf("Failed to read config: %v", err)
}

configStr = string(content)

// Check that Include directives are still present
for _, directive := range includeDirectives {
if !strings.Contains(configStr, directive) {
t.Errorf("Include directive missing after DeleteServer: %s\nConfig content:\n%s", directive, configStr)
}
}

// Verify the server was deleted
if strings.Contains(configStr, "Host newserver") {
t.Errorf("Deleted server still present in config")
}
}
56 changes: 55 additions & 1 deletion internal/adapters/data/ssh_config_file/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package ssh_config_file

import (
"reflect"
"strconv"
"strings"
"time"
"unsafe"

"github.com/Adembc/lazyssh/internal/core/domain"
"github.com/kevinburke/ssh_config"
Expand All @@ -26,8 +28,25 @@ import (
// toDomainServer converts ssh_config.Config to a slice of domain.Server.
func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server {
servers := make([]domain.Server, 0, len(cfg.Hosts))

// Process all hosts including those from Include directives
servers = r.extractHostsRecursively(cfg, servers)

return servers
}

// extractHostsRecursively extracts hosts from a config and recursively from any Include directives
func (r *Repository) extractHostsRecursively(cfg *ssh_config.Config, servers []domain.Server) []domain.Server {
for _, host := range cfg.Hosts {
// First, check for Include directives in this host's nodes
for _, node := range host.Nodes {
if inc, ok := node.(*ssh_config.Include); ok {
// Use reflection to access the private 'files' field
servers = r.extractHostsFromInclude(inc, servers)
}
}

// Then process the host itself (skip wildcards as before)
aliases := make([]string, 0, len(host.Patterns))

for _, pattern := range host.Patterns {
Expand All @@ -41,6 +60,7 @@ func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server {
if len(aliases) == 0 {
continue
}

server := domain.Server{
Alias: aliases[0],
Aliases: aliases,
Expand All @@ -63,7 +83,41 @@ func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server {
return servers
}

// mapKVToServer maps an ssh_config.KV node to the corresponding fields in domain.Server.
// extractHostsFromInclude uses reflection to extract hosts from Include nodes
func (r *Repository) extractHostsFromInclude(inc *ssh_config.Include, servers []domain.Server) []domain.Server {
// Use reflection to access the private 'files' and 'matches' fields
incValue := reflect.ValueOf(inc).Elem()
filesField := incValue.FieldByName("files")
matchesField := incValue.FieldByName("matches")

if !filesField.IsValid() || filesField.IsNil() || !matchesField.IsValid() {
return servers
}

// matches is a []string slice - iterate through it
matchesLen := matchesField.Len()
for i := 0; i < matchesLen; i++ {
matchKey := matchesField.Index(i).String()

// Get the corresponding Config from the files map
cfgValue := filesField.MapIndex(reflect.ValueOf(matchKey))
if !cfgValue.IsValid() || cfgValue.IsNil() {
continue
}

// cfgValue is a reflect.Value pointing to *Config
// We need to use Elem() to dereference the pointer, then get the pointer again
cfgPtr := cfgValue.Elem()

// Construct a *Config from the pointer address
// This is a workaround for not being able to call Interface() on unexported fields
//nolint:gosec // G103: Using unsafe to access unexported field from ssh_config library
includedCfg := (*ssh_config.Config)(unsafe.Pointer(cfgPtr.UnsafeAddr())) // Recursively process the included config
servers = r.extractHostsRecursively(includedCfg, servers)
}

return servers
} // mapKVToServer maps an ssh_config.KV node to the corresponding fields in domain.Server.
func (r *Repository) mapKVToServer(server *domain.Server, kvNode *ssh_config.KV) {
key := strings.ToLower(kvNode.Key)
value := kvNode.Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Repository struct {
fileSystem FileSystem
metadataManager *metadataManager
logger *zap.SugaredLogger
// implicitHost preserves global-level directives (Include, etc.)
// that appear before any explicit Host blocks
implicitHost *ssh_config.Host
}

// NewRepository creates a new SSH config repository.
Expand Down
Loading