From 10f783d4326aa9905965aeac249815663789264d Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Tue, 2 Jan 2024 14:04:20 +0200 Subject: [PATCH] Refactor cmd/rigtest into a real test suite in test/rig_test.go and fix found bugs (#145) * Convert cmd/rigtest into a real test in test/rig_test.go Signed-off-by: Kimmo Lehto * Adjust workflows/test scripts for test/rig_test.go Signed-off-by: Kimmo Lehto * Fix bugs found by the new improved tests Signed-off-by: Kimmo Lehto --------- Signed-off-by: Kimmo Lehto --- .github/workflows/go.yml | 39 +- cmd/rigtest/rigtest.go | 420 ------------------ connection.go | 21 +- connection_test.go | 2 +- exec/exec.go | 27 +- go.mod | 12 +- go.sum | 43 +- localhost.go | 6 +- openssh.go | 2 +- os/host.go | 3 + os/linux.go | 18 +- os/windows.go | 76 ++-- pkg/powershell/powershell.go | 13 +- pkg/rigfs/direntrybuffer.go | 50 +++ pkg/rigfs/direntrybuffer_test.go | 59 +++ pkg/rigfs/posixfsys.go | 330 +++++++-------- pkg/rigfs/rigrcp.ps1 | 324 +++++--------- pkg/rigfs/types.go | 12 +- pkg/rigfs/windir.go | 86 ++++ pkg/rigfs/winfile.go | 260 ++++++++++++ pkg/rigfs/winfileinfo.go | 84 ++++ pkg/rigfs/winfsys.go | 518 ++++------------------ pkg/rigfs/withname.go | 24 ++ ssh.go | 75 +++- test/Makefile | 22 +- test/rig_test.go | 707 +++++++++++++++++++++++++++++++ test/test.sh | 124 +++--- winrm.go | 94 ++-- 28 files changed, 1938 insertions(+), 1513 deletions(-) delete mode 100644 cmd/rigtest/rigtest.go create mode 100644 pkg/rigfs/direntrybuffer.go create mode 100644 pkg/rigfs/direntrybuffer_test.go create mode 100644 pkg/rigfs/windir.go create mode 100644 pkg/rigfs/winfile.go create mode 100644 pkg/rigfs/winfileinfo.go create mode 100644 pkg/rigfs/withname.go create mode 100644 test/rig_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b0519e2a..0b0ea80f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -4,7 +4,7 @@ on: [pull_request] jobs: - build-linux: + unit-linux: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 @@ -13,13 +13,11 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: go.mod - check-latest: true - - name: Build - run: go build -v ./... - - name: Test - run: go test -v ./... + run: | + go mod download + go test -v ./... integration-linux: strategy: @@ -29,7 +27,7 @@ jobs: - quay.io/k0sproject/bootloose-ubuntu20.04 - quay.io/k0sproject/bootloose-debian12 - quay.io/k0sproject/bootloose-alpine3.18 - needs: build-linux + needs: unit-linux runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 @@ -38,7 +36,6 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: go.mod - check-latest: true - name: install test dependencies run: | @@ -48,12 +45,16 @@ jobs: - name: Run integration tests env: LINUX_IMAGE: ${{ matrix.image }} - run: make -C test test + run: | + cd test + go mod download + make test windows: runs-on: windows-2022 steps: - name: Set up WinRM + shell: pwsh run: | Set-Item WSMan:\localhost\Service\AllowUnencrypted -Value $True Get-ChildItem WSMan:\Localhost\listener | Remove-Item -Recurse @@ -102,16 +103,14 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: go.mod - - - name: Test - run: go test -v ./... - - - name: Build rigtest - run: | - go install ./cmd/rigtest - rigtest --help - - - name: Run rigtest + - name: Unit test + run: | + go mod download + go test -v ./... + + - name: Integration test run: | - rigtest.exe -proto winrm -host 127.0.0.1:5986 -user winrmuser -pass Password123 -https + cd test + go mod download + go test -v ./ -args -protocol winrm -host 127.0.0.1 -port 5986 -user winrmuser -winrm-password Password123 -winrm-https diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go deleted file mode 100644 index d4ae87a5..00000000 --- a/cmd/rigtest/rigtest.go +++ /dev/null @@ -1,420 +0,0 @@ -package main - -import ( - "crypto/rand" - "crypto/sha256" - "errors" - "flag" - "fmt" - "io" - "io/fs" - goos "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/k0sproject/rig" - "github.com/k0sproject/rig/exec" - "github.com/k0sproject/rig/os" - "github.com/k0sproject/rig/os/registry" - _ "github.com/k0sproject/rig/os/support" - "github.com/k0sproject/rig/pkg/rigfs" - "github.com/kevinburke/ssh_config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type TestingT interface { - Errorf(format string, args ...any) - FailNow() -} - -type testRunner struct{} - -func (t testRunner) Run(name string, args ...any) { - fmt.Println("* Running test:", fmt.Sprintf(name, args...)) -} - -func (t testRunner) Errorf(format string, args ...any) { - println(fmt.Sprintf(format, args...)) -} - -func (t testRunner) FailNow() { - panic("fail") -} - -func (t testRunner) Fail(msg string) { - panic("fail: " + msg) -} - -func (t testRunner) Err(err error) { - panic("fail: " + err.Error()) -} - -type configurer interface { - WriteFile(os.Host, string, string, string) error - LineIntoFile(os.Host, string, string, string) error - ReadFile(os.Host, string) (string, error) - FileExist(os.Host, string) bool - DeleteFile(os.Host, string) error - Stat(os.Host, string, ...exec.Option) (*os.FileInfo, error) - Touch(os.Host, string, time.Time, ...exec.Option) error - MkDir(os.Host, string, ...exec.Option) error -} - -// Host is a host that utilizes rig for connections -type Host struct { - rig.Connection - - Configurer configurer -} - -// LoadOS is a function that assigns a OS support package to the host and -// typecasts it to a suitable interface -func (h *Host) LoadOS() error { - bf, err := registry.GetOSModuleBuilder(*h.OSVersion) - if err != nil { - return err - } - - h.Configurer = bf().(configurer) - - return nil -} - -func retry(fn func() error) error { - var err error - for i := 0; i < 3; i++ { - err = fn() - if err == nil { - return nil - } - time.Sleep(2 * time.Second) - } - return err -} - -func main() { - dh := flag.String("host", "127.0.0.1", "target host [+ :port], can give multiple comma separated") - usr := flag.String("user", "root", "user name") - proto := flag.String("proto", "ssh", "ssh/winrm/localhost/openssh") - kp := flag.String("keypath", "", "ssh keypath") - pc := flag.Bool("askpass", false, "ask ssh passwords") - pwd := flag.String("pass", "", "winrm password") - https := flag.Bool("https", false, "use https for winrm") - connectOnly := flag.Bool("connect", false, "just connect and quit") - sshKey := flag.String("ssh-private-key", "", "ssh private key") - multiplex := flag.Bool("ssh-multiplex", true, "use ssh multiplexing") - fsysOnly := flag.Bool("fsys", false, "only test rigfs operations") - - fn := fmt.Sprintf("test_%s.txt", time.Now().Format("20060102150405")) - - flag.Parse() - - if *dh == "" { - println("at least host required, see -help") - goos.Exit(1) - } - - if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { - f, err := goos.Open(configPath) - if err != nil { - panic(err) - } - cfg, err := ssh_config.Decode(f) - if err != nil { - panic(err) - } - rig.SSHConfigGetAll = func(dst, key string) []string { - res, err := cfg.GetAll(dst, key) - if err != nil { - return nil - } - return res - } - } - - var passfunc func() (string, error) - if *pc { - passfunc = func() (string, error) { - var pass string - fmt.Print("Password: ") - fmt.Scanln(&pass) - return pass, nil - } - } - - var hosts []*Host - - for _, address := range strings.Split(*dh, ",") { - port := 22 - if addr, portstr, ok := strings.Cut(address, ":"); ok { - address = addr - p, err := strconv.Atoi(portstr) - if err != nil { - panic("invalid port " + portstr) - } - port = p - } - - var h *Host - switch *proto { - case "ssh": - if *sshKey != "" { - // test with private key in a string - authM, err := rig.ParseSSHPrivateKey([]byte(*sshKey), rig.DefaultPasswordCallback) - if err != nil { - panic(err) - } - h = &Host{ - Connection: rig.Connection{ - SSH: &rig.SSH{ - Address: address, - Port: port, - User: *usr, - AuthMethods: authM, - }, - }, - } - } else { - h = &Host{ - Connection: rig.Connection{ - SSH: &rig.SSH{ - Address: address, - Port: port, - User: *usr, - KeyPath: kp, - PasswordCallback: passfunc, - }, - }, - } - } - case "winrm": - h = &Host{ - Connection: rig.Connection{ - WinRM: &rig.WinRM{ - Address: address, - Port: port, - User: *usr, - UseHTTPS: *https, - Insecure: true, - Password: *pwd, - }, - }, - } - case "localhost": - h = &Host{ - Connection: rig.Connection{ - Localhost: &rig.Localhost{ - Enabled: true, - }, - }, - } - case "openssh": - h = &Host{ - Connection: rig.Connection{ - OpenSSH: &rig.OpenSSH{ - Address: address, - KeyPath: kp, - DisableMultiplexing: !*multiplex, - }, - }, - } - if *usr != "" { - h.OpenSSH.User = usr - } - if port != 22 && port != 0 { - h.OpenSSH.Port = &port - } - if cfgPath := goos.Getenv("SSH_CONFIG"); cfgPath != "" { - h.OpenSSH.ConfigPath = &cfgPath - } - default: - panic("unknown protocol " + *proto) - } - hosts = append(hosts, h) - } - - t := testRunner{} - - for _, h := range hosts { - t.Run("connect %s", h.Address()) - err := retry(func() error { - err := h.Connect() - if errors.Is(err, rig.ErrCantConnect) { - t.Err(err) - } - return err - }) - - require.NoError(t, err, "connection failed") - - if *connectOnly { - continue - } - - if !*fsysOnly { - t.Run("load os %s", h.Address()) - require.NoError(t, h.LoadOS(), "load os") - - t.Run("os support module functions on %s", h) - - stat, err := h.Configurer.Stat(h, fn) - require.Error(t, err, "no stat error") - - now := time.Now() - err = h.Configurer.Touch(h, fn, now) - require.NoError(t, err, "touch error") - - stat, err = h.Configurer.Stat(h, fn) - require.NoError(t, err, "stat error") - assert.Equal(t, filepath.Base(stat.Name()), filepath.Base(fn), "stat name not as expected") - assert.Equal(t, filepath.Base(stat.Name()), filepath.Base(fn), "stat name not as expected") - assert.Condition(t, func() bool { - actual := stat.ModTime() - return now.Equal(actual) || now.Truncate(time.Second).Equal(actual) - }, "Expected %s, got %s", now, stat.ModTime()) - - require.NoError(t, h.Configurer.WriteFile(h, fn, "test\ntest2\ntest3", "0644"), "write file") - if !h.Configurer.FileExist(h, fn) { - t.Fail("file does not exist after write") - } - require.NoError(t, h.Configurer.LineIntoFile(h, fn, "test2", "test4"), "line into file") - - row, err := h.Configurer.ReadFile(h, fn) - require.NoError(t, err, "read file") - require.Equal(t, "test\ntest4\ntest3", row, "file content not as expected after line into file") - - require.NoError(t, h.Configurer.DeleteFile(h, fn)) - require.False(t, h.Configurer.FileExist(h, fn)) - } - - fsyses := []rigfs.Fsys{h.Fsys()} - if !h.IsWindows() { - // on windows using sudo makes no difference - the commands will be executed identically - // you just might not have permissions to do so. the only access elevation for command line - // on windows is "runas /user:Administrator" which requires you to enter the password of - // the Administator account. - // - // on linux, we'll test the sudo fsys as well - fsyses = append(fsyses, h.SudoFsys()) - } - - for idx, fsys := range fsyses { - for _, testFileSize := range []int64{ - int64(500), // less than one block on most filesystems - int64(1 << (10 * 2)), // exactly 1MB - int64(4096), // exactly one block on most filesystems - int64(4097), // plus 1 - } { - t.Run("fsys (%d) functions for file size %d on %s", idx+1, testFileSize, h) - - origin := io.LimitReader(rand.Reader, testFileSize) - shasum := sha256.New() - reader := io.TeeReader(origin, shasum) - - destf, err := fsys.OpenFile(fn, goos.O_CREATE|goos.O_WRONLY, 0644) - require.NoError(t, err, "open file using OpenFile") - - n, err := io.Copy(destf, reader) - require.NoError(t, err, "io.copy file from local to remote") - require.Equal(t, testFileSize, n, "file size not as expected after copy") - - require.NoError(t, destf.Close(), "error while closing file") - - fstat, err := fsys.Stat(fn) - require.NoError(t, err, "stat error") - require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result") - - destSum, err := fsys.Sha256(fn) - require.NoError(t, err, "sha256 error") - - require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote") - - destf, err = fsys.OpenFile(fn, goos.O_RDONLY, 0) - require.NoError(t, err, "open file for read") - - readSha := sha256.New() - n, err = io.Copy(readSha, destf) - require.NoError(t, err, "io.copy file from remote to local") - - require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local") - - fstat, err = destf.Stat() - require.NoError(t, err, "stat error after read") - require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result after read") - require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local") - - _, err = destf.Seek(0, 0) - require.NoError(t, err, "seek") - - readSha.Reset() - - n, err = io.Copy(readSha, destf) - require.NoError(t, err, "io.copy file from remote to local after seek") - - require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local after seek") - - require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local after seek") - - require.NoError(t, destf.Close(), "close after seek + read") - require.NoError(t, fsys.Remove(fn), "remove file") - _, err = destf.Stat() - require.ErrorIs(t, err, fs.ErrNotExist, "file still exists") - } - t.Run("fsys (%d) dir ops on %s", idx+1, h) - - // fsys dirops - require.NoError(t, fsys.MkDirAll("rigtmpdir/nested", 0644), "make nested dir") - _, err = fsys.Stat("rigtmpdir") - require.NoError(t, err, "rigtmpdir was not created") - _, err = fsys.Stat("rigtmpdir/nested") - require.NoError(t, err, "tmpdir/nested was not created") - - require.NoError(t, fsys.RemoveAll("rigtmpdir"), "remove recursive") - _, err = fsys.Stat("rigtmpdir/nested") - require.ErrorIs(t, err, fs.ErrNotExist, "nested dir still exists") - _, err = fsys.Stat("rigtmpdir") - require.ErrorIs(t, err, fs.ErrNotExist, "dir still exists") - - // create test dir structure - require.NoError(t, fsys.MkDirAll("rigtmpdir/testdir/subdir", 0755), "make dir") - - for _, fn := range []string{"rigtmpdir/testdir/subdir/testfile1", "rigtmpdir/testdir/testfile2"} { - f, err := fsys.OpenFile(fn, goos.O_CREATE|goos.O_WRONLY, 0644) - require.NoError(t, err, "open file using OpenFile") - _, err = f.Write([]byte("test")) - require.NoError(t, err, "write to file") - require.NoError(t, f.Close(), "close file") - } - - var foundFiles []fs.DirEntry - - err = fs.WalkDir(fsys, "rigtmpdir/testdir", func(path string, d fs.DirEntry, err error) error { - if err != nil { - println("error walking", path, err) - return err - } - info, err := d.Info() - if err != nil { - return err - } - if info.Mode()&fs.ModeIrregular != 0 { - return fs.SkipDir - } - - foundFiles = append(foundFiles, d) - return nil - }) - require.NoError(t, err, "walk dir") - require.Equal(t, 4, len(foundFiles), "walk dir found files") - require.Equal(t, "testdir", foundFiles[0].Name(), "walk dir found subdir") - require.Equal(t, "subdir", foundFiles[1].Name(), "walk dir found subdir") - require.Equal(t, "testfile1", foundFiles[2].Name(), "walk dir found testfile1") - require.Equal(t, "testfile2", foundFiles[3].Name(), "walk dir found testfile2") - } - t.Run("disconnect %s", h.Address()) - h.Disconnect() - } -} diff --git a/connection.go b/connection.go index 5e966f09..736e40d4 100644 --- a/connection.go +++ b/connection.go @@ -12,25 +12,21 @@ import ( "github.com/alessio/shellescape" "github.com/creasty/defaults" - "github.com/google/shlex" "github.com/k0sproject/rig/exec" "github.com/k0sproject/rig/log" rigos "github.com/k0sproject/rig/os" "github.com/k0sproject/rig/pkg/rigfs" + "github.com/mattn/go-shellwords" ) var _ rigos.Host = (*Connection)(nil) -type waiter interface { - Wait() error -} - type client interface { Connect() error Disconnect() IsWindows() bool Exec(cmd string, opts ...exec.Option) error - ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (waiter, error) + ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) ExecInteractive(cmd string) error String() string Protocol() string @@ -187,7 +183,7 @@ func (c *Connection) IsWindows() bool { // ExecStreams executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. -func (c Connection) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (rigfs.Waiter, error) { +func (c Connection) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { if err := c.checkConnected(); err != nil { return nil, fmt.Errorf("%w: exec with streams: %w", ErrCommandFailed, err) } @@ -255,7 +251,7 @@ func sudoNoop(cmd string) string { } func sudoSudo(cmd string) string { - parts, err := shlex.Split(cmd) + parts, err := shellwords.Parse(cmd) if err != nil { return "sudo -s -- " + cmd } @@ -400,15 +396,20 @@ func (c *Connection) Upload(src, dst string, _ ...exec.Option) error { shasum := sha256.New() fsys := c.Fsys() - remote, err := fsys.OpenFile(dst, os.O_CREATE|os.O_WRONLY, stat.Mode()) + remote, err := fsys.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, stat.Mode()) if err != nil { return fmt.Errorf("%w: open remote file %s for writing: %w", ErrInvalidPath, dst, err) } defer remote.Close() - if _, err := remote.CopyFromN(local, stat.Size(), shasum); err != nil { + localReader := io.TeeReader(local, shasum) + if _, err := io.Copy(remote, localReader); err != nil { + _ = remote.Close() return fmt.Errorf("%w: copy file %s to remote host: %w", ErrUploadFailed, dst, err) } + if err := remote.Close(); err != nil { + return fmt.Errorf("%w: close remote file %s: %w", ErrUploadFailed, dst, err) + } log.Debugf("%s: post-upload validate checksum of %s", c, dst) remoteSum, err := fsys.Sha256(dst) diff --git a/connection_test.go b/connection_test.go index 07644c2e..f19acff2 100644 --- a/connection_test.go +++ b/connection_test.go @@ -39,7 +39,7 @@ func (m *mockClient) Exec(cmd string, opts ...exec.Option) error { return nil } -func (m *mockClient) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) { +func (m *mockClient) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { return nil, fmt.Errorf("not implemented") } diff --git a/exec/exec.go b/exec/exec.go index dddf61e8..131746d9 100644 --- a/exec/exec.go +++ b/exec/exec.go @@ -3,6 +3,7 @@ package exec import ( "bufio" + "encoding/base64" "fmt" "io" "os" @@ -47,6 +48,11 @@ var ( mutex sync.Mutex ) +// Waiter is a process that can be waited to finish +type Waiter interface { + Wait() error +} + // Option is a functional option for the exec package type Option func(*Options) @@ -85,6 +91,23 @@ func (o *Options) Command(cmd string) (string, error) { return out, nil } +func decodeEncoded(cmd string) string { + if !strings.Contains(cmd, "powershell") { + return cmd + } + + parts := strings.Split(cmd, " ") + for i, p := range parts { + if p == "-E" || p == "-EncodedCommand" && len(parts) > i+1 { + decoded, err := base64.StdEncoding.DecodeString(parts[i+1]) + if err == nil { + parts[i+1] = strings.ReplaceAll(string(decoded), "\x00", "") + } + } + } + return strings.Join(parts, " ") +} + // LogCmd is for logging the command to be executed func (o *Options) LogCmd(prefix, cmd string) { if Confirm { @@ -97,9 +120,9 @@ func (o *Options) LogCmd(prefix, cmd string) { } if o.LogCommand { - DebugFunc("%s: executing `%s`", prefix, o.Redact(cmd)) + DebugFunc("%s: executing `%s`", prefix, o.Redact(decodeEncoded(cmd))) } else { - DebugFunc("%s: executing [REDACTED]", prefix) + DebugFunc("%s: executing command", prefix) } } diff --git a/go.mod b/go.mod index ced8c3b4..39905b49 100644 --- a/go.mod +++ b/go.mod @@ -8,10 +8,9 @@ require ( github.com/alessio/shellescape v1.4.2 github.com/creasty/defaults v1.7.0 github.com/davidmz/go-pageant v1.0.2 - github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 - github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 github.com/kevinburke/ssh_config v1.2.0 - github.com/masterzen/winrm v0.0.0-20220917170901-b07f6cb0598d + github.com/masterzen/winrm v0.0.0-20231128182143-52a9e15d5730 + github.com/mattn/go-shellwords v1.0.12 github.com/mitchellh/go-homedir v1.1.0 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 @@ -21,8 +20,12 @@ require ( require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/ChrisTrenkamp/goxpath v0.0.0-20210404020558-97928f7e12b6 // indirect + github.com/bodgit/ntlmssp v0.0.0-20231128222409-0a45a2447e7c // indirect + github.com/bodgit/windows v1.0.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logr/logr v1.3.0 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/jcmturner/aescts/v2 v2.0.0 // indirect github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect @@ -32,8 +35,9 @@ require ( github.com/jcmturner/rpc/v2 v2.0.3 // indirect github.com/masterzen/simplexml v0.0.0-20190410153822-31eea3082786 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/net v0.17.0 // indirect + golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.9.3 // indirect diff --git a/go.sum b/go.sum index 95310588..f7388a1c 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -github.com/Azure/go-ntlmssp v0.0.0-20211209120228-48547f28849e/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/ChrisTrenkamp/goxpath v0.0.0-20210404020558-97928f7e12b6 h1:w0E0fgc1YafGEh5cROhlROMWXiNoZqApk2PDN0M1+Ns= @@ -9,6 +8,10 @@ github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpH github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= +github.com/bodgit/ntlmssp v0.0.0-20231128222409-0a45a2447e7c h1:W7dgPjcG1Rx+wOtKeviTPYSrzf/Um/ew9wBXY0IjP2s= +github.com/bodgit/ntlmssp v0.0.0-20231128222409-0a45a2447e7c/go.mod h1:X0rVAs8xRc5mkV/xTR3yCWqCN64bVnLxjqI8DHAbN0k= +github.com/bodgit/windows v1.0.1 h1:tF7K6KOluPYygXa3Z2594zxlkbKPAOvqr97etrGNIz4= +github.com/bodgit/windows v1.0.1/go.mod h1:a6JLwrB4KrTR5hBpp8FI9/9W9jJfeQ2h4XDXU74ZCdM= github.com/creasty/defaults v1.7.0 h1:eNdqZvc5B509z18lD8yc212CAqJNvfT1Jq6L8WowdBA= github.com/creasty/defaults v1.7.0/go.mod h1:iGzKe6pbEHnpMPtfDXZEr0NVxWnPTjb1bbDy08fPzYM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,17 +19,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0= github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE= -github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= +github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= -github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -34,29 +38,26 @@ github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFK github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/masterzen/simplexml v0.0.0-20190410153822-31eea3082786 h1:2ZKn+w/BJeL43sCxI2jhPLRv73oVVOjEKZjKkflyqxg= github.com/masterzen/simplexml v0.0.0-20190410153822-31eea3082786/go.mod h1:kCEbxUJlNDEBNbdQMkPSp6yaKcRXVI6f4ddk8Riv4bc= -github.com/masterzen/winrm v0.0.0-20220917170901-b07f6cb0598d h1:GXlX1g/AjI3/izilmeMvP/aHWYCuwOZXpJsS0XdGVls= -github.com/masterzen/winrm v0.0.0-20220917170901-b07f6cb0598d/go.mod h1:Iju3u6NzoTAvjuhsGCZc+7fReNnr/Bd6DsWj3WTokIU= +github.com/masterzen/winrm v0.0.0-20231128182143-52a9e15d5730 h1:zLVOWGRxX/IsRpqHjl0hjVq6BORcs7ubih+G2dGhTEs= +github.com/masterzen/winrm v0.0.0-20231128182143-52a9e15d5730/go.mod h1:qfAjztAGRm7J7Ci10OA9vrx8WRDM0mlhdsFu7gBtMK8= +github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk= +github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -65,43 +66,37 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde h1:AMNpJRc7P+GTwVbl8DkK2I9I8BBUzNiHuH/tlxrpan0= +github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde/go.mod h1:MvrEmduDUz4ST5pGZ7CABCnOU5f3ZiOAZzT6b1A6nX8= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -115,7 +110,6 @@ golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= @@ -126,7 +120,6 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/localhost.go b/localhost.go index 96b2cb0f..ea4965df 100644 --- a/localhost.go +++ b/localhost.go @@ -11,7 +11,7 @@ import ( "sync" "github.com/k0sproject/rig/exec" - "github.com/kballard/go-shellquote" + "github.com/mattn/go-shellwords" ) const name = "[local] localhost" @@ -56,7 +56,7 @@ func (c *Localhost) Disconnect() {} // ExecStreams executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. -func (c *Localhost) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) { +func (c *Localhost) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { execOpts := exec.Build(opts...) command, err := c.command(cmd, execOpts) if err != nil { @@ -179,7 +179,7 @@ func (c *Localhost) ExecInteractive(cmd string) error { Dir: cwd, } - parts, err := shellquote.Split(cmd) + parts, err := shellwords.Parse(cmd) if err != nil { return fmt.Errorf("failed to parse command: %w", err) } diff --git a/openssh.go b/openssh.go index 88d79d5a..17764bf4 100644 --- a/openssh.go +++ b/openssh.go @@ -323,7 +323,7 @@ func (c *OpenSSH) Exec(cmdStr string, opts ...exec.Option) error { //nolint:cycl } // ExecStreams executes a command on the remote host, streaming stdin, stdout and stderr -func (c *OpenSSH) ExecStreams(cmdStr string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) { +func (c *OpenSSH) ExecStreams(cmdStr string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { if !c.DisableMultiplexing && !c.isConnected { return nil, ErrNotConnected } diff --git a/os/host.go b/os/host.go index 1ac4f3d0..dee7fc34 100644 --- a/os/host.go +++ b/os/host.go @@ -1,6 +1,8 @@ package os import ( + "io" + "github.com/k0sproject/rig/exec" ) @@ -11,6 +13,7 @@ type Host interface { ExecOutput(cmd string, opts ...exec.Option) (string, error) Execf(cmd string, argsOrOpts ...any) error ExecOutputf(cmd string, argsOrOpts ...any) (string, error) + ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) String() string Sudo(cmd string) (string, error) } diff --git a/os/linux.go b/os/linux.go index 318e8a95..74ddc285 100644 --- a/os/linux.go +++ b/os/linux.go @@ -2,6 +2,7 @@ package os import ( "bufio" + "bytes" "errors" "fmt" "io/fs" @@ -256,11 +257,15 @@ func (c Linux) InstallFile(h Host, src, dst, permissions string) error { // ReadFile reads a files contents from the host. func (c Linux) ReadFile(h Host, path string) (string, error) { - out, err := h.ExecOutputf("cat -- %s 2> /dev/null", shellescape.Quote(path), exec.HideOutput(), exec.Sudo(h)) + out := bytes.NewBuffer(nil) + cmd, err := h.ExecStreams(fmt.Sprintf("cat -- %s 2> /dev/null", shellescape.Quote(path)), nil, out, nil, exec.Sudo(h)) if err != nil { return "", fmt.Errorf("failed to read file %s: %w", path, err) } - return out, nil + if err := cmd.Wait(); err != nil { + return "", fmt.Errorf("failed to read file %s: %w", path, err) + } + return out.String(), nil } // DeleteFile deletes a file from the host. @@ -480,3 +485,12 @@ func (c Linux) Touch(h Host, path string, ts time.Time, opts ...exec.Option) err } return nil } + +// Sha256sum calculates the sha256 checksum of a file +func (c Linux) Sha256sum(h Host, path string, opts ...exec.Option) (string, error) { + out, err := h.ExecOutput(fmt.Sprintf("sha256sum -b -- %s 2> /dev/null", shellescape.Quote(path)), opts...) + if err != nil { + return "", fmt.Errorf("failed to shasum %s: %w", path, err) + } + return strings.Split(out, " ")[0], nil +} diff --git a/os/windows.go b/os/windows.go index 95a3dde7..2999c92d 100644 --- a/os/windows.go +++ b/os/windows.go @@ -2,6 +2,7 @@ package os import ( "bufio" + "errors" "fmt" "io/fs" "strconv" @@ -9,10 +10,11 @@ import ( "time" "github.com/k0sproject/rig/exec" - "github.com/k0sproject/rig/log" ps "github.com/k0sproject/rig/pkg/powershell" ) +var errNotSupported = errors.New("not supported on windows") + // Windows is the base package for windows OS support type Windows struct{} @@ -46,7 +48,7 @@ func (c Windows) InstallPackage(h Host, s ...string) error { // InstallFile on windows is a regular file move operation func (c Windows) InstallFile(h Host, src, dst, _ string) error { - if err := h.Execf("move /y %s %s", ps.DoubleQuote(src), ps.DoubleQuote(dst), exec.Sudo(h)); err != nil { + if err := h.Execf("move /y %s %s", ps.DoubleQuotePath(src), ps.DoubleQuotePath(dst), exec.Sudo(h)); err != nil { return fmt.Errorf("failed to move %s to %s: %w", src, dst, err) } return nil @@ -103,36 +105,17 @@ func (c Windows) SELinuxEnabled(_ Host) bool { // WriteFile writes file to host with given contents. Do not use for large files. // The permissions argument is ignored on windows. func (c Windows) WriteFile(h Host, path string, data string, _ string) error { - if data == "" { - return fmt.Errorf("%w: empty content for writing to %s", ErrCommandFailed, path) - } - - if path == "" { - return fmt.Errorf("%w: empty path for file writing %s", ErrCommandFailed, path) - } - - tempFile, err := h.ExecOutput("powershell -Command \"New-TemporaryFile | Write-Host\"") - if err != nil { - return fmt.Errorf("failed to create temporary file: %w", err) - } - defer c.deleteTempFile(h, tempFile) - - err = h.Exec(fmt.Sprintf(`powershell -Command "$Input | Out-File -FilePath %s"`, ps.SingleQuote(tempFile)), exec.Stdin(data), exec.RedactString(data)) - if err != nil { - return fmt.Errorf("failed to write to temporary file: %w", err) - } - - err = h.Exec(fmt.Sprintf(`powershell -Command "Move-Item -Force -Path %s -Destination %s"`, ps.SingleQuote(tempFile), ps.SingleQuote(path))) + err := h.Exec(fmt.Sprintf(`powershell -Command "$Input | Out-File -FilePath %s"`, ps.DoubleQuotePath(path)), exec.Stdin(data), exec.RedactString(data)) if err != nil { - return fmt.Errorf("failed to move temporary file to %s: %w", path, err) + return fmt.Errorf("failed to write to file %s: %w", path, err) } return nil } -// ReadFile reads a files contents from the host. +// ReadFile reads a file's contents from the host. func (c Windows) ReadFile(h Host, path string) (string, error) { - out, err := h.ExecOutput(fmt.Sprintf(`type %s`, ps.DoubleQuote(path)), exec.HideOutput()) + out, err := h.ExecOutput(fmt.Sprintf(`type %s`, ps.DoubleQuotePath(path)), exec.HideOutput()) if err != nil { return "", fmt.Errorf("failed to read file %s: %w", path, err) } @@ -141,21 +124,15 @@ func (c Windows) ReadFile(h Host, path string) (string, error) { // DeleteFile deletes a file from the host. func (c Windows) DeleteFile(h Host, path string) error { - if err := h.Exec(fmt.Sprintf(`del /f %s`, ps.DoubleQuote(path))); err != nil { + if err := h.Exec(fmt.Sprintf(`del /f %s`, ps.DoubleQuotePath(path))); err != nil { return fmt.Errorf("failed to delete file %s: %w", path, err) } return nil } -func (c Windows) deleteTempFile(h Host, path string) { - if err := c.DeleteFile(h, path); err != nil { - log.Debugf("failed to delete temporary file %s: %v", path, err) - } -} - // FileExist checks if a file exists on the host func (c Windows) FileExist(h Host, path string) bool { - return h.Exec(fmt.Sprintf(`powershell -Command "if (!(Test-Path -Path \"%s\")) { exit 1 }"`, path)) == nil + return h.Exec(fmt.Sprintf(`powershell -Command "if (!(Test-Path -Path \"%s\")) { exit 1 }"`, ps.DoubleQuotePath(path))) == nil } // UpdateEnvironment updates the hosts's environment variables @@ -209,7 +186,7 @@ func (c Windows) Reboot(h Host) error { // StartService starts a service func (c Windows) StartService(h Host, s string) error { - if err := h.Execf(`sc start "%s"`, s); err != nil { + if err := h.Execf(`sc start %s`, ps.DoubleQuote(s)); err != nil { return fmt.Errorf("failed to start service %s: %w", s, err) } return nil @@ -217,7 +194,7 @@ func (c Windows) StartService(h Host, s string) error { // StopService stops a service func (c Windows) StopService(h Host, s string) error { - if err := h.Execf(`sc stop "%s"`, s); err != nil { + if err := h.Execf(`sc stop %s`, ps.DoubleQuote(s)); err != nil { return fmt.Errorf("failed to stop service %s: %w", s, err) } return nil @@ -225,12 +202,12 @@ func (c Windows) StopService(h Host, s string) error { // ServiceScriptPath returns the path to a service configuration file func (c Windows) ServiceScriptPath(_ Host, _ string) (string, error) { - return "", fmt.Errorf("%w: service scripts not supported on windows", ErrCommandFailed) + return "", errNotSupported } // RestartService restarts a service func (c Windows) RestartService(h Host, s string) error { - if err := h.Execf(ps.Cmd(fmt.Sprintf(`Restart-Service "%s"`, s))); err != nil { + if err := h.Execf(ps.Cmd(fmt.Sprintf(`Restart-Service %s`, ps.DoubleQuote(s)))); err != nil { return fmt.Errorf("failed to restart service %s: %w", s, err) } return nil @@ -243,7 +220,7 @@ func (c Windows) DaemonReload(_ Host) error { // EnableService enables a service func (c Windows) EnableService(h Host, s string) error { - if err := h.Execf(`sc.exe config "%s" start=enabled`, s); err != nil { + if err := h.Execf(`sc.exe config %s start=enabled`, ps.DoubleQuote(s)); err != nil { return fmt.Errorf("failed to enable service %s: %w", s, err) } @@ -252,7 +229,7 @@ func (c Windows) EnableService(h Host, s string) error { // DisableService disables a service func (c Windows) DisableService(h Host, s string) error { - if err := h.Execf(`sc.exe config "%s" start=disabled`, s); err != nil { + if err := h.Execf(`sc.exe config %s start=disabled`, ps.DoubleQuote(s)); err != nil { return fmt.Errorf("failed to disable service %s: %w", s, err) } return nil @@ -260,7 +237,7 @@ func (c Windows) DisableService(h Host, s string) error { // ServiceIsRunning returns true if a service is running func (c Windows) ServiceIsRunning(h Host, s string) bool { - return h.Execf(`sc.exe query "%s" | findstr "RUNNING"`, s) == nil + return h.Execf(`sc.exe query %s | findstr "RUNNING"`, ps.DoubleQuote(s)) == nil } // MkDir creates a directory (including intermediate directories) @@ -281,7 +258,7 @@ func (c Windows) Chmod(_ Host, _, _ string, _ ...exec.Option) error { func (c Windows) Stat(h Host, path string, opts ...exec.Option) (*FileInfo, error) { info := &FileInfo{FName: path, FMode: fs.FileMode(0)} - out, err := h.ExecOutput(ps.Cmd(fmt.Sprintf("[System.Math]::Truncate((Get-Date -Date ((Get-Item %s).LastWriteTime.ToUniversalTime()) -UFormat %%s))", ps.DoubleQuote(path))), opts...) + out, err := h.ExecOutput(ps.Cmd(fmt.Sprintf("[System.Math]::Truncate((Get-Date -Date ((Get-Item -LiteralPath %s).LastWriteTime.ToUniversalTime()) -UFormat %%s))", ps.DoubleQuotePath(path))), opts...) if err != nil { return nil, fmt.Errorf("failed to get file %s modtime: %w", path, err) } @@ -291,7 +268,7 @@ func (c Windows) Stat(h Host, path string, opts ...exec.Option) (*FileInfo, erro } info.FModTime = time.Unix(ts, 0) - out, err = h.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-Item %s).Length", ps.DoubleQuote(path))), opts...) + out, err = h.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-Item -LiteralPath %s).Length", ps.DoubleQuotePath(path))), opts...) if err != nil { return nil, fmt.Errorf("failed to get file %s size: %w", path, err) } @@ -301,7 +278,7 @@ func (c Windows) Stat(h Host, path string, opts ...exec.Option) (*FileInfo, erro } info.FSize = size - out, err = h.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-Item %s).GetType().Name", ps.DoubleQuote(path))), opts...) + out, err = h.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-Item -LiteralPath %s).GetType().Name", ps.DoubleQuotePath(path))), opts...) if err != nil { return nil, fmt.Errorf("failed to get file %s type: %w", path, err) } @@ -313,12 +290,12 @@ func (c Windows) Stat(h Host, path string, opts ...exec.Option) (*FileInfo, erro // Touch updates a file's last modified time or creates a new empty file func (c Windows) Touch(h Host, path string, ts time.Time, opts ...exec.Option) error { if !c.FileExist(h, path) { - if err := h.Exec(ps.Cmd(fmt.Sprintf("Set-Content -Path %s -value $null", ps.DoubleQuote(path))), opts...); err != nil { + if err := h.Exec(ps.Cmd(fmt.Sprintf("Set-Content -LiteralPath %s -value $null", ps.DoubleQuotePath(path))), opts...); err != nil { return fmt.Errorf("failed to create file %s: %w", path, err) } } - err := h.Exec(ps.Cmd(fmt.Sprintf("(Get-Item %s).LastWriteTime = (Get-Date %s)", ps.DoubleQuote(path), ps.DoubleQuote(ts.Format(time.RFC3339)))), opts...) + err := h.Exec(ps.Cmd(fmt.Sprintf("(Get-Item -LiteralPath %s).LastWriteTime = (Get-Date %s)", ps.DoubleQuotePath(path), ps.DoubleQuote(ts.Format(time.RFC3339)))), opts...) if err != nil { return fmt.Errorf("failed to update file %s timestamp: %w", path, err) } @@ -355,3 +332,12 @@ func (c Windows) LineIntoFile(h Host, path, matcher, newLine string) error { return c.WriteFile(h, path, writer.String(), "0644") } + +// Sha256sum returns the sha256sum of a file +func (c Windows) Sha256sum(h Host, path string, opts ...exec.Option) (string, error) { + sum, err := h.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-FileHash %s -Algorithm SHA256).Hash.ToLower()", ps.DoubleQuotePath(path))), opts...) + if err != nil { + return "", fmt.Errorf("failed to get sha256sum for %s: %w", path, err) + } + return strings.TrimSpace(sum), nil +} diff --git a/pkg/powershell/powershell.go b/pkg/powershell/powershell.go index 88861036..ebe5438c 100644 --- a/pkg/powershell/powershell.go +++ b/pkg/powershell/powershell.go @@ -87,7 +87,7 @@ func SingleQuote(v string) string { return buf.String() } -// DoubleQuote escapes a string in a way that can be used as a windows file path +// DoubleQuote adds double quotes around a string and escapes any double quotes inside. func DoubleQuote(v string) string { if v[0] == '"' && v[len(v)-1] == '"' { // already quoted @@ -107,3 +107,14 @@ func DoubleQuote(v string) string { _, _ = buf.WriteRune('"') return buf.String() } + +// DoubleQuotePath adds double quotes around a string and escapes any double quotes inside. +// It also converts forward slashes to backslashes. +func DoubleQuotePath(v string) string { + return DoubleQuote(ToWindowsPath(v)) +} + +// ToWindowsPath converts a unix-style forward slash separated path to a windows-style path +func ToWindowsPath(v string) string { + return strings.ReplaceAll(v, "/", "\\") +} diff --git a/pkg/rigfs/direntrybuffer.go b/pkg/rigfs/direntrybuffer.go new file mode 100644 index 00000000..b00cbd45 --- /dev/null +++ b/pkg/rigfs/direntrybuffer.go @@ -0,0 +1,50 @@ +package rigfs + +import ( + "io" + "io/fs" + "sort" +) + +type dirEntryBuffer struct { + entries []fs.DirEntry +} + +func newDirEntryBuffer(entries []fs.DirEntry) *dirEntryBuffer { + sort.Slice(entries, func(i, j int) bool { + isDirI, isDirJ := entries[i].IsDir(), entries[j].IsDir() + + // If both are directories or files, sort alphabetically + if isDirI == isDirJ { + return entries[i].Name() < entries[j].Name() + } + + // Otherwise, directories should come first + return isDirI + }) + + return &dirEntryBuffer{entries: entries} +} + +// Next returns the next n entries from the buffer. +// Subsequent calls on the same file will yield further DirEntry values. +// When there are no more entries, io.EOF is returned. +// A negative count returns all the remaining entries in the buffer. +func (b *dirEntryBuffer) Next(n int) ([]fs.DirEntry, error) { + if len(b.entries) == 0 { + return nil, io.EOF + } + + if n == 0 { + return nil, nil + } + + if n < 0 || n > len(b.entries) { + n = len(b.entries) + } + + // Retrieve the next n entries + entries := b.entries[:n] + b.entries = b.entries[n:] + return entries, nil +} diff --git a/pkg/rigfs/direntrybuffer_test.go b/pkg/rigfs/direntrybuffer_test.go new file mode 100644 index 00000000..3796147a --- /dev/null +++ b/pkg/rigfs/direntrybuffer_test.go @@ -0,0 +1,59 @@ +package rigfs + +import ( + "errors" + "io" + "io/fs" + "testing" +) + +func TestDirEntryBuffer(t *testing.T) { + // Create mock DirEntry slices for testing + mockEntries := []fs.DirEntry{ + mockDirEntry{name: "file1"}, + mockDirEntry{name: "file2"}, + mockDirEntry{name: "file3"}, + } + + // Test cases + tests := []struct { + name string + n int + initEntries []fs.DirEntry + wantEntries []int + wantErr []error + }{ + {"Empty Buffer", 1, []fs.DirEntry{}, []int{0}, []error{io.EOF}}, + {"Single Call", 2, mockEntries, []int{2}, []error{nil}}, + {"Multiple Calls", 1, mockEntries, []int{1, 1, 1, 0}, []error{nil, nil, nil, io.EOF}}, + {"Exact Count", 3, mockEntries, []int{3}, []error{nil}}, + {"Negative Count", -1, mockEntries, []int{3}, []error{nil}}, + {"End of Buffer", 10, mockEntries, []int{3, 0}, []error{nil, io.EOF}}, + {"Zero Count", 0, mockEntries, []int{0}, []error{nil}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + buffer := newDirEntryBuffer(tc.initEntries) + for i, want := range tc.wantEntries { + entries, err := buffer.Next(tc.n) + if len(entries) != want { + t.Errorf("Call %d: got %d entries, want %d", i+1, len(entries), want) + } + if !errors.Is(err, tc.wantErr[i]) { + t.Errorf("Call %d: got error %v, want %v", i+1, err, tc.wantErr[i]) + } + } + }) + } +} + +// mockDirEntry is a mock implementation of fs.DirEntry for testing purposes. +type mockDirEntry struct { + name string +} + +func (m mockDirEntry) Name() string { return m.name } +func (m mockDirEntry) IsDir() bool { return false } +func (m mockDirEntry) Type() fs.FileMode { return 0 } +func (m mockDirEntry) Info() (fs.FileInfo, error) { return nil, nil } diff --git a/pkg/rigfs/posixfsys.go b/pkg/rigfs/posixfsys.go index 4e13514b..bf6b5a72 100644 --- a/pkg/rigfs/posixfsys.go +++ b/pkg/rigfs/posixfsys.go @@ -14,12 +14,14 @@ import ( "github.com/alessio/shellescape" "github.com/k0sproject/rig/exec" + "github.com/k0sproject/rig/log" ) var ( - _ fs.File = (*PosixFile)(nil) - _ fs.ReadDirFile = (*PosixDir)(nil) - _ fs.FS = (*PosixFsys)(nil) + _ fs.File = (*PosixFile)(nil) + _ fs.ReadDirFile = (*PosixDir)(nil) + _ fs.FS = (*PosixFsys)(nil) + errInvalid = errors.New("invalid") ) // PosixFsys implements fs.FS for a remote filesystem that uses POSIX commands for access @@ -42,8 +44,8 @@ const ( // PosixFile implements fs.File for a remote file type PosixFile struct { + withPath fsys *PosixFsys - path string isOpen bool isEOF bool pos int64 @@ -57,35 +59,19 @@ type PosixFile struct { // PosixDir implements fs.ReadDirFile for a remote directory type PosixDir struct { PosixFile - entries []fs.DirEntry - hw int + buffer *dirEntryBuffer } // ReadDir returns a list of directory entries func (f *PosixDir) ReadDir(n int) ([]fs.DirEntry, error) { - if n == 0 { - return f.fsys.ReadDir(f.path) - } - if f.entries == nil { + if f.buffer == nil { entries, err := f.fsys.ReadDir(f.path) if err != nil { return nil, err } - f.entries = entries - f.hw = 0 - } - if f.hw >= len(f.entries) { - return nil, io.EOF + f.buffer = newDirEntryBuffer(entries) } - var min int - if n > len(f.entries)-f.hw { - min = len(f.entries) - f.hw - } else { - min = n - } - old := f.hw - f.hw += min - return f.entries[old:f.hw], nil + return f.buffer.Next(n) } func (f *PosixFile) fsBlockSize() int { @@ -112,18 +98,18 @@ func (f *PosixFile) isWritable() bool { return f.isOpen && f.flags&os.O_WRONLY != 0 } -func (f *PosixFile) ddParams(offset int64, numBytes int) (int, int64, int) { - bs := f.fsBlockSize() +func (f *PosixFile) ddParams(offset int64, numBytes int) (blocksize int, skip int64, count int) { //nolint:nonamedreturns // for readability + optimalBs := f.fsBlockSize() - if numBytes < bs { - bs = numBytes - skip := offset / int64(bs) - return bs, skip, 1 + // if numBytes aligns with the optimal block size, use it; otherwise, use bs = 1 + bs := optimalBs + if numBytes%optimalBs != 0 { + bs = 1 } - skip := offset / int64(bs) - count := numBytes / bs - return bs, skip, count + s := offset / int64(bs) + c := (numBytes + bs - 1) / bs + return bs, s, c } // Stat returns a FileInfo describing the named file @@ -137,37 +123,47 @@ func (f *PosixFile) Read(p []byte) (int, error) { return 0, io.EOF } if !f.isReadable() { - return 0, fmt.Errorf("%w: file %s is not open for reading", ErrCommandFailed, f.path) + return 0, fmt.Errorf("%w: file %s is not open for reading", fs.ErrClosed, f.path) } - errbuf := bytes.NewBuffer(nil) bs, skip, count := f.ddParams(f.pos, len(p)) - toRead := bs * count - buf := bytes.NewBuffer(nil) - errbuf.Reset() - cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, buf, errbuf, f.fsys.opts...) + cmdStr := fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count) + + // Execute the command + buf := bytes.NewBuffer(nil) + errbuf := bytes.NewBuffer(nil) + cmd, err := f.fsys.conn.ExecStreams(cmdStr, nil, buf, errbuf, f.fsys.opts...) if err != nil { - return 0, fmt.Errorf("%w: failed to execute dd: %w (%s)", ErrCommandFailed, err, errbuf.String()) + return 0, fmt.Errorf("failed to execute dd: %w (%s)", err, errbuf.String()) } if err := cmd.Wait(); err != nil { - return 0, fmt.Errorf("%w: read (dd): %w (%s)", ErrCommandFailed, err, errbuf.String()) + return 0, fmt.Errorf("read (dd): %w (%s)", err, errbuf.String()) } - readBytes := copy(p, buf.Bytes()) - f.pos += int64(readBytes) - if readBytes < len(p) || readBytes < toRead { + readBytes := buf.Bytes() + + // Trim extra data if readBytes is larger than the requested size + if len(readBytes) > len(p) { + readBytes = readBytes[:len(p)] + } + + copied := copy(p, readBytes) + f.pos += int64(copied) + + if copied < len(p) { f.isEOF = true - return readBytes, io.EOF } - return readBytes, nil + return copied, nil } func (f *PosixFile) Write(p []byte) (int, error) { if !f.isWritable() { - return 0, fmt.Errorf("%w: file %s is not open for writing", ErrCommandFailed, f.path) + return 0, fmt.Errorf("%w: file %s is not open for writing", fs.ErrClosed, f.path) } + log.Debugf("writing %d bytes to %s", len(p), f.path) + var written int remaining := p for written < len(p) { @@ -176,12 +172,16 @@ func (f *PosixFile) Write(p []byte) (int, error) { errbuf := bytes.NewBuffer(nil) limitedReader := bytes.NewReader(remaining[:toWrite]) - cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), io.NopCloser(limitedReader), io.Discard, errbuf, f.fsys.opts...) + cmd, err := f.fsys.conn.ExecStreams( + fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), + io.NopCloser(limitedReader), io.Discard, errbuf, + f.fsys.opts..., + ) if err != nil { - return 0, fmt.Errorf("%w: write (dd): %w", ErrCommandFailed, err) + return 0, fmt.Errorf("write (dd): %w", err) } if err := cmd.Wait(); err != nil { - return 0, fmt.Errorf("%w: write (dd): %w (%s)", ErrCommandFailed, err, errbuf.String()) + return 0, fmt.Errorf("write (dd): %w (%s)", err, errbuf.String()) } written += toWrite @@ -199,68 +199,22 @@ func (f *PosixFile) Write(p []byte) (int, error) { return written, nil } -// CopyFromN copies n bytes from the remote file. The alt writer can be used for progress -// tracking, use nil when not needed. -func (f *PosixFile) CopyFromN(src io.Reader, num int64, alt io.Writer) (int64, error) { - if !f.isWritable() { - return 0, fmt.Errorf("%w: file %s is not open for writing", ErrCommandFailed, f.path) - } - var ddCmd string - if f.pos+num >= f.size { - // truncate to current position - if err := f.fsys.Truncate(f.path, f.pos); err != nil { - return 0, fmt.Errorf("%w: truncate %s for writing: %w", ErrCommandFailed, f.path, err) - } - ddCmd = fmt.Sprintf("dd if=/dev/stdin of=%s bs=16M oflag=append conv=notrunc", shellescape.Quote(f.path)) - } else { - ddCmd = fmt.Sprintf("dd if=/dev/stdin of=%s bs=1 seek=%d conv=notrunc", shellescape.Quote(f.path), f.pos) - } - limited := io.LimitReader(src, num) - var reader io.Reader - if alt != nil { - reader = io.TeeReader(limited, alt) - } else { - reader = limited - } - - errbuf := bytes.NewBuffer(nil) - cmd, err := f.fsys.conn.ExecStreams(ddCmd, io.NopCloser(reader), io.Discard, errbuf, f.fsys.opts...) - if err != nil { - return 0, fmt.Errorf("%w: failed to execute dd (copy-from): %w (%s)", ErrCommandFailed, err, errbuf.String()) - } - if err != nil { - return 0, fmt.Errorf("%w: copy-from: %w", ErrCommandFailed, err) - } - f.pos += num - if f.pos >= f.size { - f.isEOF = true - f.size = f.pos - } - if err != nil { - return 0, &fs.PathError{Op: "copy-from", Path: f.path, Err: fmt.Errorf("%w: error while copying: %w", ErrRcpCommandFailed, err)} - } - if err := cmd.Wait(); err != nil { - return 0, &fs.PathError{Op: "copy-from", Path: f.path, Err: fmt.Errorf("%w: error while copying: %w (%s)", ErrRcpCommandFailed, err, errbuf.String())} - } - return num, nil -} - // Copy copies the remote file at src to the local file at dst func (f *PosixFile) Copy(dst io.Writer) (int64, error) { if f.isEOF { return 0, io.EOF } if !f.isReadable() { - return 0, fmt.Errorf("%w: file %s is not open for reading", ErrCommandFailed, f.path) + return 0, f.pathErr("copy", fmt.Errorf("%w: file %s is not open for reading", fs.ErrClosed, f.path)) } bs, skip, count := f.ddParams(f.pos, int(f.size-f.pos)) errbuf := bytes.NewBuffer(nil) cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, dst, errbuf, f.fsys.opts...) if err != nil { - return 0, fmt.Errorf("%w: failed to execute dd (copy): %w (%s)", ErrCommandFailed, err, errbuf.String()) + return 0, f.pathErr("copy", fmt.Errorf("failed to execute dd: %w (%s)", err, errbuf.String())) } if err := cmd.Wait(); err != nil { - return 0, fmt.Errorf("%w: copy (dd): %w (%s)", ErrCommandFailed, err, errbuf.String()) + return 0, f.pathErr("copy", fmt.Errorf("dd: %w (%s)", err, errbuf.String())) } f.pos = f.size f.isEOF = true @@ -274,9 +228,9 @@ func (f *PosixFile) Close() error { } // Seek sets the offset for the next Read or Write to offset, interpreted according to whence: -// 0 means relative to the origin of the file, -// 1 means relative to the current offset, and -// 2 means relative to the end. +// io.SeekStart means relative to the origin of the file, +// io.SeekCurrent means relative to the current offset, and +// io.SeekEnd means relative to the end. // Seek returns the new offset relative to the start of the file and an error, if any. func (f *PosixFile) Seek(offset int64, whence int) (int64, error) { switch whence { @@ -287,7 +241,7 @@ func (f *PosixFile) Seek(offset int64, whence int) (int64, error) { case io.SeekEnd: f.pos = f.size + offset default: - return 0, fmt.Errorf("%w: invalid whence: %d", ErrCommandFailed, whence) + return 0, fmt.Errorf("%w: whence: %d", errInvalid, whence) } f.isEOF = f.pos >= f.size @@ -306,7 +260,7 @@ func (fsys *PosixFsys) initStat() error { opts = append(opts, exec.HideOutput()) out, err := fsys.conn.ExecOutput("stat --help 2>&1", opts...) if err != nil { - return fmt.Errorf("%w: can't access stat command: %w", ErrCommandFailed, err) + return fmt.Errorf("can't access stat command: %w", err) } if strings.Contains(out, "BusyBox") || strings.Contains(out, "--format=") { fsys.statCmd = &statCmdGNU @@ -358,7 +312,7 @@ func (fsys *PosixFsys) parseStat(stat string) (*FileInfo, error) { // output looks like: 0x81a4 0 1699970097.220228000 //test_20231114155456.txt// parts := strings.SplitN(stat, " ", 4) if len(parts) != 4 { - return nil, fmt.Errorf("%w: stat parse output %s", ErrCommandFailed, stat) + return nil, fmt.Errorf("%w: parse stat output %s", errInvalid, stat) } res := &FileInfo{fsys: fsys} @@ -366,13 +320,13 @@ func (fsys *PosixFsys) parseStat(stat string) (*FileInfo, error) { if strings.HasPrefix(parts[0], "0x") { m, err := strconv.ParseInt(parts[0][2:], 16, 64) if err != nil { - return nil, fmt.Errorf("%w: stat parse mode %s: %w", ErrCommandFailed, stat, err) + return nil, fmt.Errorf("parse stat mode %s: %w", stat, err) } res.FMode = posixBitsToFileMode(m) } else { m, err := strconv.ParseInt(parts[0], 8, 64) if err != nil { - return nil, fmt.Errorf("%w: stat parse mode %s: %w", ErrCommandFailed, stat, err) + return nil, fmt.Errorf("parse stat mode %s: %w", stat, err) } res.FMode = posixBitsToFileMode(m) } @@ -381,20 +335,20 @@ func (fsys *PosixFsys) parseStat(stat string) (*FileInfo, error) { size, err := strconv.ParseInt(parts[1], 10, 64) if err != nil { - return nil, fmt.Errorf("%w: stat parse size %s: %w", ErrCommandFailed, stat, err) + return nil, fmt.Errorf("parse stat size %s: %w", stat, err) } res.FSize = size timeParts := strings.SplitN(parts[2], ".", 2) mtime, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return nil, fmt.Errorf("%w: stat parse mtime %s: %w", ErrCommandFailed, stat, err) + return nil, fmt.Errorf("parse stat mtime %s: %w", stat, err) } var mtimeNano int64 if len(timeParts) == 2 { mtimeNano, err = strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return nil, fmt.Errorf("%w: stat parse mtime ns %s: %w", ErrCommandFailed, stat, err) + return nil, fmt.Errorf("parse stat mtime ns %s: %w", stat, err) } } res.FModTime = time.Unix(mtime, mtimeNano) @@ -426,9 +380,9 @@ func (fsys *PosixFsys) multiStat(names ...string) ([]fs.FileInfo, error) { out, err := fsys.conn.ExecOutput(fmt.Sprintf(*fsys.statCmd, batch.String()), fsys.opts...) if err != nil { if len(names) == 1 { - return nil, &fs.PathError{Op: "stat", Path: names[0], Err: fs.ErrNotExist} + return nil, &fs.PathError{Op: OpStat, Path: names[0], Err: fs.ErrNotExist} } - return nil, fmt.Errorf("%w: stat %s: %w", ErrCommandFailed, names, err) + return nil, fmt.Errorf("stat %s: %w", names, err) } lines := strings.Split(out, "\n") for _, line := range lines { @@ -453,11 +407,11 @@ func (fsys *PosixFsys) Stat(name string) (fs.FileInfo, error) { } switch len(items) { case 0: - return nil, &fs.PathError{Op: "stat", Path: name, Err: fs.ErrNotExist} + return nil, &fs.PathError{Op: OpStat, Path: name, Err: fs.ErrNotExist} case 1: return items[0], nil default: - return nil, fmt.Errorf("%w: stat %s: too many results", ErrCommandFailed, name) + return nil, fmt.Errorf("%w: stat %s: too many results", errInvalid, name) } } @@ -468,11 +422,11 @@ func (fsys *PosixFsys) Sha256(name string) (string, error) { if isNotExist(err) { return "", &fs.PathError{Op: "sha256sum", Path: name, Err: fs.ErrNotExist} } - return "", fmt.Errorf("%w: sha256sum %s: %w", ErrCommandFailed, name, err) + return "", fmt.Errorf("sha256sum %s: %w", name, err) } sha := strings.Fields(out)[0] if len(sha) != 64 { - return "", fmt.Errorf("%w: sha256sum invalid output %s: %s", ErrCommandFailed, name, out) + return "", fmt.Errorf("%w: sha256sum invalid output %s: %s", errInvalid, name, out) } return sha, nil } @@ -481,7 +435,7 @@ func (fsys *PosixFsys) Sha256(name string) (string, error) { func (fsys *PosixFsys) Touch(name string) error { err := fsys.conn.Exec(fmt.Sprintf("touch %s", shellescape.Quote(name)), fsys.opts...) if err != nil { - return fmt.Errorf("%w: touch %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("touch %s: %w", name, err) } return nil } @@ -495,7 +449,7 @@ func (fsys *PosixFsys) secTouchT(name string, t time.Time) error { shellescape.Quote(name), ) if err := fsys.conn.Exec(cmd, fsys.opts...); err != nil { - return fmt.Errorf("%w: touch %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("touch %s: %w", name, err) } return nil } @@ -510,7 +464,7 @@ func (fsys *PosixFsys) nsecTouchT(name string, t time.Time) error { shellescape.Quote(name), ) if err := fsys.conn.Exec(cmd, fsys.opts...); err != nil { - return fmt.Errorf("%w: touch (ns) %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("touch (ns) %s: %w", name, err) } return nil } @@ -531,7 +485,7 @@ func (fsys *PosixFsys) TouchT(name string, t time.Time) error { // Truncate changes the size of the named file or creates a new file if it doesn't exist func (fsys *PosixFsys) Truncate(name string, size int64) error { if err := fsys.conn.Exec(fmt.Sprintf("truncate -s %d %s", size, shellescape.Quote(name)), fsys.opts...); err != nil { - return fmt.Errorf("%w: truncate %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("truncate %s: %w", name, err) } return nil } @@ -542,7 +496,7 @@ func (fsys *PosixFsys) Chmod(name string, mode fs.FileMode) error { if isNotExist(err) { return &fs.PathError{Op: "chmod", Path: name, Err: fs.ErrNotExist} } - return fmt.Errorf("%w: chmod %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("chmod %s: %w", name, err) } return nil } @@ -552,72 +506,80 @@ func (fsys *PosixFsys) Open(name string) (fs.File, error) { return fsys.OpenFile(name, os.O_RDONLY, 0) } +func (fsys *PosixFsys) openNew(name string, flags int, perm fs.FileMode) (fs.FileInfo, error) { + if flags&os.O_CREATE == 0 { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fs.ErrNotExist} + } + + if _, err := fsys.Stat(filepath.Dir(name)); err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fmt.Errorf("%w: parent directory does not exist", fs.ErrNotExist)} + } + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fmt.Errorf("%w: failed to stat parent directory", fs.ErrInvalid)} + } + + if err := fsys.conn.Exec(fmt.Sprintf("install -m %#o /dev/null %s", perm, shellescape.Quote(name)), fsys.opts...); err != nil { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: err} + } + + // re-stat to ensure file is now there and get the correct bits if there's a umask + return fsys.Stat(name) +} + +func (fsys *PosixFsys) openExisting(name string, flags int, info fs.FileInfo) (fs.FileInfo, error) { + // directories can't be opened for writing + if info.IsDir() && flags&(os.O_WRONLY|os.O_RDWR|os.O_CREATE|os.O_EXCL) != 0 { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fmt.Errorf("%w: is a directory", fs.ErrInvalid)} + } + + // if O_CREATE and O_EXCL are set, the file must not exist + if flags&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL) { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fs.ErrExist} + } + + if flags&os.O_TRUNC != 0 { + if err := fsys.Truncate(name, 0); err != nil { + return nil, err + } + } + + return fsys.Stat(name) +} + // OpenFile is used to open a file with access/creation flags for reading or writing. For info on flags, // see https://pkg.go.dev/os#pkg-constants -func (fsys *PosixFsys) OpenFile(name string, flags int, perm fs.FileMode) (File, error) { //nolint:cyclop +func (fsys *PosixFsys) OpenFile(name string, flags int, perm fs.FileMode) (File, error) { if flags&^supportedFlags != 0 { - return nil, fmt.Errorf("%w: unsupported flags: %d", ErrCommandFailed, flags) + return nil, fmt.Errorf("%w: unsupported flags: %d", errInvalid, flags) } - var pos int64 info, err := fsys.Stat(name) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return nil, err + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return nil, err + } + info, err = fsys.openNew(name, flags, perm) + } else { + info, err = fsys.openExisting(name, flags, info) } - fileExists := err == nil - - switch fileExists { - case false: - switch flags & os.O_CREATE { - case 0: - return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} - default: - if _, err := fsys.Stat(filepath.Dir(name)); err != nil { - if errors.Is(err, fs.ErrNotExist) { - return nil, &fs.PathError{Op: "open", Path: name, Err: fmt.Errorf("%w: parent directory does not exist", fs.ErrNotExist)} - } - return nil, &fs.PathError{Op: "open", Path: name, Err: fmt.Errorf("%w: failed to stat parent directory", fs.ErrInvalid)} - } - if err := fsys.conn.Exec(fmt.Sprintf("install -m %#o /dev/null %s", perm, shellescape.Quote(name)), fsys.opts...); err != nil { - return nil, &fs.PathError{Op: "open", Path: name, Err: err} - } + if err != nil { + return nil, err + } - // re-stat to ensure file is now there and get the correct bits if there's a umask - i, err := fsys.Stat(name) - if err != nil { - return nil, err - } - info = i - } - case true: - switch { - case info.IsDir(): - return nil, &fs.PathError{Op: "open", Path: name, Err: fmt.Errorf("%w: is a directory", fs.ErrInvalid)} - case flags&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL): - return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrExist} - case flags&os.O_TRUNC != 0: - if err := fsys.Truncate(name, 0); err != nil { - return nil, err - } - i, err := fsys.Stat(name) - if err != nil { - return nil, err - } - info = i - case flags&os.O_APPEND != 0: - pos = info.Size() - } + var pos int64 + if flags&os.O_APPEND != 0 { + pos = info.Size() } file := &PosixFile{ - fsys: fsys, - path: name, - isOpen: true, - size: info.Size(), - pos: pos, - mode: info.Mode(), - flags: flags, + withPath: withPath{name}, + fsys: fsys, + isOpen: true, + size: info.Size(), + pos: pos, + mode: info.Mode(), + flags: flags, } if info.IsDir() { return &PosixDir{PosixFile: *file}, nil @@ -635,9 +597,9 @@ func (fsys *PosixFsys) ReadDir(name string) ([]fs.DirEntry, error) { name = "." } - out, err := fsys.conn.ExecOutput(fmt.Sprintf("find %[1]s -maxdepth 1 -print0 | sort -z", shellescape.Quote(name)), fsys.opts...) + out, err := fsys.conn.ExecOutput(fmt.Sprintf("find %s -maxdepth 1 -print0", shellescape.Quote(name)), fsys.opts...) if err != nil { - return nil, fmt.Errorf("%w: read dir (find) %s: %w", ErrCommandFailed, name, err) + return nil, fmt.Errorf("read dir (find) %s: %w", name, err) } items := strings.Split(out, "\x00") if len(items) == 0 || (len(items) == 1 && items[0] == "") { @@ -653,11 +615,9 @@ func (fsys *PosixFsys) ReadDir(name string) ([]fs.DirEntry, error) { res := make([]fs.DirEntry, 0, len(items)-1) infos, err := fsys.multiStat(items[1:]...) for _, entry := range infos { - info, ok := entry.(fs.DirEntry) - if !ok { - return res, fmt.Errorf("%w: read dir: entry is not a FileInfo %s", ErrCommandFailed, name) + if info, ok := entry.(fs.DirEntry); ok { + res = append(res, info) } - res = append(res, info) } return res, err } @@ -665,7 +625,7 @@ func (fsys *PosixFsys) ReadDir(name string) ([]fs.DirEntry, error) { // Remove deletes the named file or (empty) directory. func (fsys *PosixFsys) Remove(name string) error { if err := fsys.conn.Exec(fmt.Sprintf("rm -f %s", shellescape.Quote(name)), fsys.opts...); err != nil { - return fmt.Errorf("%w: delete %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("delete %s: %w", name, err) } return nil } @@ -677,7 +637,7 @@ func isNotExist(err error) bool { // RemoveAll removes path and any children it contains. func (fsys *PosixFsys) RemoveAll(name string) error { if err := fsys.conn.Exec(fmt.Sprintf("rm -rf %s", shellescape.Quote(name)), fsys.opts...); err != nil { - return fmt.Errorf("%w: remove all %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("remove all %s: %w", name, err) } return nil } @@ -690,11 +650,11 @@ func (fsys *PosixFsys) MkDirAll(name string, perm fs.FileMode) error { if existing.IsDir() { return nil } - return fmt.Errorf("%w: mkdir %s: %w", ErrCommandFailed, name, fs.ErrExist) + return fmt.Errorf("mkdir %s: %w", name, fs.ErrExist) } if err := fsys.conn.Exec(fmt.Sprintf("install -d -m %#o %s", perm, shellescape.Quote(dir)), fsys.opts...); err != nil { - return fmt.Errorf("%w: mkdir %s: %w", ErrCommandFailed, name, err) + return fmt.Errorf("mkdir %s: %w", name, err) } return nil diff --git a/pkg/rigfs/rigrcp.ps1 b/pkg/rigfs/rigrcp.ps1 index 1f2ec974..f9ad989c 100644 --- a/pkg/rigfs/rigrcp.ps1 +++ b/pkg/rigfs/rigrcp.ps1 @@ -1,93 +1,48 @@ begin { Set-Alias NO New-Object - - class Stat { - [int]$size - [int]$mode - [int]$unixMode - [int]$modTime - [bool]$isDir - [string]$name - Stat([IO.FileSystemInfo]$fi){ - if($fi.Exists -eq $false){ - throw ("file not found") - } - $this.isDir=($fi.Attributes -band [IO.FileAttributes]::Directory) - $this.modTime=[int](Get-Date ($fi.LastWriteTimeUtc).ToUniversalTime() -UFormat %s) - $this.size=[int]$fi.Length - $this.unixMode=[int]$fi.UnixFileMode - $this.mode=[int]$fi.Attributes - $this.name=$fi.FullName - } - } - class FileContext { - [IO.FileStream]$f - [bool]$EOF - - FileContext([IO.FileStream]$f){ - $this.f=$f + function Close-Dispose($obj){ + if ($obj -ne $null){ + $obj.Close() + $obj.Dispose() } } - class FM { - hidden [hashtable]$f = @{} - - [string] Add($ctx) { - $id=[guid]::NewGuid().ToString() - $this.f[$id] = $ctx - return $id - } - - [FileContext] Get([string]$id) { - if ($this.f.ContainsKey($id)) { - return $this.f[$id] - } else { - throw "file not open" - } - } - - [void] Del([string]$id) { - if ($this.f.ContainsKey($id)) { - $v=$this.f[$id].f - $this.f.Remove($id) - $v.Close() - $v.Dispose() - $v=$null - } - } - - [void] Close() { - $this.f.Values|ForEach-Object { - $this.Del($_) - } - } + function Emit($s, $obj){ + $j=ConvertTo-Json -InputObject $obj -Depth 10 -Compress + $b=[System.Text.Encoding]::UTF8.GetBytes($j + [char]0) + $s.Write($b, 0, $b.Length) + $s.Flush() } - function FSInfo($path){ - try { - $p=Resolve-Path $path - } catch { - if(![IO.Path]::IsPathRooted($path)){ - $p=Join-Path $pwd $path + function Invoke-WithRetry($Script, $Retries, $Match) { + for ($i = 1; $i -le $Retries; $i++) { + try { + &$Script + return + } catch { + if ($_.Exception.Message -like $Match) { + Write-Warning "retrying" + Start-Sleep -Seconds 1 + } else { + throw + } + } } - } - - if(Test-Path $p -PathType Container){ - return (NO IO.DirectoryInfo($p)) - } - return (NO IO.FileInfo($p)) } - function Emit($s, $obj){ - $j=ConvertTo-Json -InputObject $obj -Depth 10 -Compress - $b=[System.Text.Encoding]::UTF8.GetBytes($j) - $s.Write($b, 0, $b.Length) - $s.WriteByte(0) - } + function HexDump { + # Assuming the first argument to the function is the buffer + $buffer = $args[0] - $bufSize=32768 - + for ($i = 0; $i -lt $buffer.Length; $i += 16) { + $line = $buffer[$i..([Math]::Min($i + 15, $buffer.Length - 1))] + $hex = ($line | ForEach-Object { "{0:X2}" -f $_ }) -join " " + $text = ($line | ForEach-Object { if ($_ -ge 32 -and $_ -le 126) {[char]$_} else {"."} }) -join "" + $offset = "{0:X8}" -f $i + "${offset}: $hex $text" + } + } $DebugPreference="Continue" $ErrorActionPreference="Stop" $ProgressPreference="SilentlyContinue" @@ -101,216 +56,145 @@ public static extern IntPtr GetStdHandle(int nStdHandle); '@ $Kernel32=Add-Type -MemberDefinition $MethodDefinitions -Name 'Kernel32' -Namespace 'Win32' -PassThru $inHandle=$Kernel32::GetStdHandle(-10) + $outHandle = $Kernel32::GetStdHandle(-11) + $in=NO IO.FileStream $inHandle, ([IO.FileAccess]::Read), ([IO.FileShare]::Read), 16384, $false $inStream=NO IO.StreamReader $in - [Console]::OutputEncoding=[System.Text.Encoding]::ASCII - $out=[System.Console]::OpenStandardOutput() + $out = New-Object IO.FileStream $outHandle, ([IO.FileAccess]::Write), ([IO.FileShare]::Write), 16384, $false - $buf=NO byte[] $bufSize - # create a file manager - $fm=[FM]::new() + $f=$null + $quit=$false + $p=$null - while(!$inStream.EndOfStream){ + while(!$inStream.EndOfStream -And !$quit){ try { $command=$inStream.ReadLine() $arg=$command -split " " switch ($arg[0]){ - 'stat' { - $p=$arg[1..($arg.Length-1)] -join " " - $fi=FSInfo $p - $info=NO Stat $fi - $o=@{ - stat=$info - } - Emit $out $o - } - 'dir' { - $p=$arg[1..($arg.Length-1)] -join " " - $di=FSInfo $p - if($di -is [IO.FileInfo]){ - throw "not a directory" - } - if(!$di.Exists){ - throw "directory not found" - } + # open + 'o' { + if ($f -ne $null){ throw "file already open" } + $script:mode=$arg[1] + $script:access=$arg[2] + $path=$arg[3..($arg.Length-1)] -join " " try { - $di.GetAccessControl()|Out-Null + $p=Resolve-Path $path } catch { - throw "access denied" - } - $infos=@() - $di.GetFileSystemInfos()|ForEach-Object { - $info=NO Stat $_ - $infos += $info - } - $o=@{ - dir=$infos + if(![IO.Path]::IsPathRooted($path)){ + $p=Join-Path $pwd $path + } } - Emit $out $o - } - # open - 'o' { - $mode=$arg[1] - $p=$arg[2..($arg.Length-1)] -join " " - $fi=FSInfo $p - $p=$fi.FullName - $fmode=$null - $m=[IO.FileMode] - switch ($mode){ - 'ro' {$fmode=$m::Open} - 'w' {$fmode=$m::CreateNew} - 'rw' {$fmode=$m::ReadWrite} - 'c' {if($fi.Exists){$fmode=$m::Truncate} else {$fmode=$m::Create}} - 'a' {$fmode=$m::Append} - default {throw "invalid mode"} + if(Test-Path $p -PathType Container){ + throw "cannot open directory" } - - $f=NO IO.FileStream($p, $fmode) - $ctx=[FileContext]::new($f) - $id=$fm.Add($ctx) - $props=@{ - id=$id - pos=$ctx.f.Position - eof=$ctx.EOF - name=$ctx.f.Name + $script:fi=NO IO.FileInfo($p) + $script:f=$null + Invoke-WithRetry -Retries 10 -Match "*used by another process*" -Script { + $script:f=$script:fi.Open($script:mode, $script:access) + } + $f=$script:f + $script:f=$null + if ($f -eq $null){ + throw "file not opened" } + $pos=$f.Position $o=@{ - open=$props + pos=$pos } Emit $out $o } # seek - 'seek' { - $ctx=$fm.Get($arg[1]) - $f=$ctx.f - $whence=[int]$arg[2] - $pos=$arg[3] - $cp=$f.Position - switch ($whence){ - 0 {$cp=$pos} - 1 {$cp+=$pos} - 2 {$cp=$f.Length-[Math]::Abs($pos)} - default { - throw "invalid whence" - } - } - $f.Position=$cp - $ctx.EOF=$cp -ge $f.Length - $props=@{ - position=[int64]$cp - } + 's' { + if ($f -eq $null){ throw "file not open" } + $pos=$arg[1] + $whence=$arg[2] + $pos=$f.Seek($pos, $whence) $o=@{ - seek=$props + pos=$pos } Emit $out $o } # read 'r' { - $ctx=$fm.Get($arg[1]) - if($ctx.EOF){ - throw "eof" - } - if(-not ($f -is [IO.FileStream] -and $f.CanRead)){ - throw "file not open for writing" - } - - $cnt=[int]$arg[2] - if($cnt -eq 0){ - throw "zero count" - } - - $f=$ctx.f - + if ($f -eq $null){ throw "file not open" } + $cnt=[int]$arg[1] if($cnt -eq -1){ $total=$f.Length - $f.Position $pos=$f.Length - $props=@{ - bytes=$total - } $o=@{ - read=$props - } + n=$total + } Emit $out $o $out.Flush() $f.CopyTo($out) $out.Flush() - $ctx.EOF=$true continue } - if($cnt -gt $bufSize){ - throw ("count exceeds buffer size "+$bufSize) - } - if($f.EndOfStream){ - $ctx.EOF=$true throw "eof" } - - $b=$f.Read($buf, 0, $count) - + + $buf=NO byte[] $cnt + $b=$f.Read($buf, 0, $cnt) if($b -eq 0){ - $ctx.EOF=$true throw "eof" } - $props=@{ - bytes=$b - } $o=@{ - read=$props + n=$b } Emit $out $o $out.Flush() $out.Write($buf, 0, $b) $out.Flush() + $buf=$null } # write 'w' { - $ctx=$fm.Get($arg[1]) - $f=$ctx.f - if(-not ($f -is [IO.FileStream] -and $f.CanWrite)){ - throw "file not open for writing" - } - - $count=[int]$arg[2] - if($count -eq 0){ - throw "zero count" - } - $out.WriteByte(0) - - $rt=0 - while ($rt -lt $count){ - $toRead=[Math]::Min($bufSize, $count - $rt) - $r=$in.Read($buf, 0, $toRead) - $f.Write($buf, 0, $r) - $rt += $r + if ($f -eq $null){ throw "file not open" } + $cnt=[int]$arg[1] + $o=@{ + n=$cnt } - $ctx.EOF=$true + Emit $out $o + $out.Flush() + $buf=NO byte[] $cnt + $b=$in.Read($buf, 0, $cnt) + if($b -ne $cnt){ + $dump=HexDump $buf + throw "short read $b bytes instead of $cnt bytes\n$dump" + } + $f.Write($buf, 0, $b) + $f.Flush() + $buf=$null } - # close 'c' { - $fm.Del($arg[1]) - $out.WriteByte(0) + if ($f -eq $null){ throw "file not open" } + Close-Dispose $f + $f=$null + $o=@{ + pos=-1 + } + Emit $out $o } 'q' { - throw "quit" + $quit=$true + Close-Dipose $out + Close-Dipose $inStream + Close-Dipose $in } default { throw "invalid command" } } } catch { - if($_.Exception.Message -eq "quit"){ - break - } $msg=@{ error=$_.Exception.Message } Emit $out $msg } } - $fm.Close() } diff --git a/pkg/rigfs/types.go b/pkg/rigfs/types.go index 975953e0..980bb4e8 100644 --- a/pkg/rigfs/types.go +++ b/pkg/rigfs/types.go @@ -11,25 +11,19 @@ import ( // ErrCommandFailed is returned when a remote command fails var ErrCommandFailed = errors.New("command failed") -// Waiter is an interface that has a Wait() function that blocks until a command is finished -type Waiter interface { - Wait() error -} - type connection interface { IsWindows() bool Exec(cmd string, opts ...exec.Option) error ExecOutput(cmd string, opts ...exec.Option) (string, error) - ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (Waiter, error) + ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) } // File is a file in the remote filesystem type File interface { fs.File - io.WriteCloser io.Seeker - Copy(dest io.Writer) (int64, error) - CopyFromN(src io.Reader, count int64, dest io.Writer) (int64, error) + io.ReadCloser + io.Writer } // Fsys is a filesystem on the remote host diff --git a/pkg/rigfs/windir.go b/pkg/rigfs/windir.go new file mode 100644 index 00000000..f2c0500c --- /dev/null +++ b/pkg/rigfs/windir.go @@ -0,0 +1,86 @@ +package rigfs + +import ( + "encoding/json" + "fmt" + "io/fs" + "os" + + ps "github.com/k0sproject/rig/pkg/powershell" +) + +var ( + _ fs.ReadDirFile = (*winDir)(nil) + _ File = (*winDir)(nil) +) + +// winDir is a directory on a Windows target. It implements fs.ReadDirFile. +type winDir struct { + winFileDirBase + buffer *dirEntryBuffer +} + +func (f *winDir) Read(_ []byte) (int, error) { + return 0, f.pathErr("read", fmt.Errorf("%w: is a directory", fs.ErrInvalid)) +} + +func (f *winDir) Seek(_ int64, _ int) (int64, error) { + return 0, f.pathErr("seek", fmt.Errorf("%w: is a directory", fs.ErrInvalid)) +} + +func (f *winDir) Write(_ []byte) (int, error) { + return 0, f.pathErr("write", fmt.Errorf("%w: is a directory", fs.ErrInvalid)) +} + +func (f *winDir) Close() error { + if f.closed { + return f.pathErr("close", fs.ErrClosed) + } + f.closed = true + return nil +} + +var statDirTemplate = ` +$items = Get-ChildItem -LiteralPath %s | Select-Object Name, FullName, LastWriteTime, Attributes, Mode, Length | ForEach-Object { + $isReadOnly = [bool]($_.Attributes -band [System.IO.FileAttributes]::ReadOnly) + $_ | Add-Member -NotePropertyName IsReadOnly -NotePropertyValue $isReadOnly -PassThru +} +if ($items -eq $null) { + throw "does not exist" +} +ConvertTo-Json -Compress -Depth 5 @($items) +` + +// ReadDir reads the contents of the directory and returns +// a slice of up to n fs.DirEntry values in directory order. +// Subsequent calls on the same file will yield further DirEntry values. +func (f *winDir) ReadDir(n int) ([]fs.DirEntry, error) { + if f.buffer == nil { + out, err := f.fsys.conn.ExecOutput(ps.Cmd(fmt.Sprintf(statDirTemplate, f.path)), f.fsys.opts...) + if err != nil { + return nil, fmt.Errorf("readdir: %w", err) + } + var fileinfos []*winFileInfo + if err := json.Unmarshal([]byte(out), &fileinfos); err != nil { + return nil, fmt.Errorf("decode readdir output: %w", err) + } + entries := make([]fs.DirEntry, len(fileinfos)) + for i, info := range fileinfos { + entries[i] = info + } + f.buffer = newDirEntryBuffer(entries) + } + return f.buffer.Next(n) +} + +func (f *winDir) open(flags int) error { + if f.closed { + return f.pathErr("open", fs.ErrClosed) + } + + if flags&(os.O_WRONLY|os.O_RDWR|os.O_APPEND|os.O_CREATE|os.O_TRUNC|os.O_EXCL) != 0 { + return f.pathErr("open", fmt.Errorf("%w: incompatible flags for directory access", fs.ErrInvalid)) + } + + return nil +} diff --git a/pkg/rigfs/winfile.go b/pkg/rigfs/winfile.go new file mode 100644 index 00000000..34ff523b --- /dev/null +++ b/pkg/rigfs/winfile.go @@ -0,0 +1,260 @@ +package rigfs + +import ( + "bufio" + _ "embed" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "os" + "strings" + + "github.com/k0sproject/rig/log" + ps "github.com/k0sproject/rig/pkg/powershell" +) + +//go:embed rigrcp.ps1 +var rigrcp string + +var ( + _ fs.File = (*winFile)(nil) + rigRcp = ps.CompressedCmd(rigrcp) + errEnded = errors.New("rigrcp ended") + errRemote = errors.New("remote error") +) + +type rcpResponse struct { + Err string `json:"error"` + N int64 `json:"n"` + Pos int64 `json:"pos"` +} + +type winFileDirBase struct { + withPath + fsys *WinFsys + closed bool +} + +// Stat returns the FileInfo for the remote file. +func (w *winFileDirBase) Stat() (fs.FileInfo, error) { + return w.fsys.Stat(w.path) +} + +// winFile is a file on a Windows target. It implements fs.File. +type winFile struct { + winFileDirBase + stdin io.WriteCloser + stdout *bufio.Reader + done chan struct{} +} + +// Seek sets the offset for the next Read or Write on the remote file. +// The whence argument controls the interpretation of offset. +// io.SeekStart = offset from the beginning of file +// io.SeekCurrent = offset from the current position +// io.SeekEnd = offset from the end of file +func (f *winFile) Seek(offset int64, whence int) (int64, error) { + if f.closed { + return 0, f.pathErr(OpSeek, fs.ErrClosed) + } + var seekOrigin string + switch whence { + case io.SeekStart: + seekOrigin = "Begin" + case io.SeekCurrent: + seekOrigin = "Current" + case io.SeekEnd: + seekOrigin = "End" + default: + return 0, f.pathErr(OpSeek, fmt.Errorf("%w: invalid whence %d", fs.ErrInvalid, whence)) + } + resp, err := f.command(fmt.Sprintf("s %d %s", offset, seekOrigin)) + if err != nil { + return 0, f.pathErr(OpSeek, err) + } + return resp.Pos, nil +} + +// Write writes len(p) bytes from p to the remote file. +func (f *winFile) Write(p []byte) (int, error) { + if f.closed { + return 0, f.pathErr(OpWrite, fs.ErrClosed) + } + _, err := f.command(fmt.Sprintf("w %d", len(p))) + if err != nil { + return 0, f.pathErr(OpWrite, err) + } + n, err := f.stdin.Write(p) + if err != nil { + return n, err //nolint:wrapcheck + } + log.Tracef("wrote %d bytes", n) + return n, nil +} + +// Read reads up to len(p) bytes from the remote file. +func (f *winFile) Read(p []byte) (int, error) { + if f.closed { + return 0, f.pathErr(OpRead, fs.ErrClosed) + } + resp, err := f.command(fmt.Sprintf("r %d", len(p))) + if err != nil { + return 0, err + } + if resp.N == 0 { + return 0, io.EOF + } + total := 0 + for total < int(resp.N) { + n, err := f.stdout.Read(p[total:resp.N]) + if err != nil { + return total, err //nolint:wrapcheck + } + total += n + } + log.Tracef("read %d bytes", total) + return total, nil +} + +func fAccess(flags int) string { + switch { + case flags&(os.O_WRONLY|os.O_TRUNC|os.O_APPEND) != 0: + return "Write" + case flags&os.O_RDWR != 0: + return "ReadWrite" + default: + return "Read" + } +} + +func fMode(flags int) string { + switch { + case flags&os.O_CREATE != 0: + if flags&os.O_EXCL != 0 { + return "CreateNew" + } + return "OpenOrCreate" + case flags&os.O_TRUNC != 0: + return "Truncate" + case flags&os.O_APPEND != 0: + return "Append" + default: + return "Open" + } +} + +func (f *winFile) open(flags int) error { + if f.closed { + return f.pathErr(OpOpen, fs.ErrClosed) + } + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() + f.stdin = stdinW + f.stdout = bufio.NewReader(stdoutR) + f.done = make(chan struct{}) + + cmd, err := f.fsys.conn.ExecStreams(rigRcp, stdinR, stdoutW, stderrW, f.fsys.opts...) + if err != nil { + return f.pathErr(OpOpen, fmt.Errorf("start file daemon: %w", err)) + } + go func() { + _, _ = io.Copy(io.Discard, stderrR) + }() + go func() { + log.Debugf("rigrcp started, waiting for exit") + err := cmd.Wait() + close(f.done) + log.Debugf("rigrcp ended") + if err != nil { + log.Errorf("rigrcp exited with error: %v", err) + } + f.closed = true + _ = stdinR.Close() + _ = stdinW.Close() + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }() + + resp, err := f.command(fmt.Sprintf("o %s %s %s", fMode(flags), fAccess(flags), f.path)) + if err != nil { + return f.pathErr(OpOpen, err) + } + if resp.Err != "" { + return f.pathErr(OpOpen, fmt.Errorf("remote error: %s", resp.Err)) //nolint:goerr113 + } + + return nil +} + +func (f *winFile) command(cmd string) (*rcpResponse, error) { //nolint:cyclop + if f.closed { + return nil, f.pathErr(OpOpen, fs.ErrClosed) + } + resp := make(chan []byte, 1) + if cmd != "q" { + go func() { + b, err := f.stdout.ReadBytes(0) + if err != nil { + log.Errorf("failed to read response: %v", err) + close(resp) + return + } + resp <- b[:len(b)-1] // drop the zero byte + }() + } + log.Debugf("rigrcp command: %s", cmd) + _, err := fmt.Fprintf(f.stdin, "%s\n", cmd) + if err != nil { + return nil, f.pathErr(OpOpen, fmt.Errorf("write command: %w", err)) + } + if cmd == "q" { + return &rcpResponse{}, nil + } + select { + case <-f.done: + return nil, errEnded + case data, ok := <-resp: + out := &rcpResponse{} + if !ok || data == nil || len(data) == 0 { + return out, nil + } + if err := json.Unmarshal(data, out); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + if e := out.Err; e != "" { + if strings.HasPrefix(e, "eof") { + return nil, io.EOF + } + if strings.Contains(e, "does not exist") { + return nil, fs.ErrNotExist + } + return nil, fmt.Errorf("%w: %s", errRemote, e) + } + return out, nil + } +} + +func (f *winFile) Close() error { + resp, err := f.command("c") + if err != nil { + return f.pathErr(OpClose, err) + } + if resp.Err != "" { + return f.pathErr(OpClose, fmt.Errorf("%w: %s", errRemote, resp.Err)) + } + if resp.Pos != -1 { + return f.pathErr(OpClose, fmt.Errorf("%w: failed to close file", errRemote)) + } + _, err = f.command("q") + log.Tracef("rigrcp quit: %v", err) + f.stdin.Close() + f.closed = true + + return nil +} diff --git a/pkg/rigfs/winfileinfo.go b/pkg/rigfs/winfileinfo.go new file mode 100644 index 00000000..c009ce12 --- /dev/null +++ b/pkg/rigfs/winfileinfo.go @@ -0,0 +1,84 @@ +package rigfs + +import ( + "fmt" + "io/fs" + "strconv" + "strings" + "time" +) + +var _ fs.FileInfo = (*winFileInfo)(nil) + +type windowsFileInfoTime time.Time + +func (t *windowsFileInfoTime) UnmarshalJSON(b []byte) error { + strTime := strings.Trim(string(b), "\"\\/Date()") + milliseconds, err := strconv.ParseInt(strTime, 10, 64) + if err != nil { + return fmt.Errorf("decode time: %w", err) + } + + seconds := milliseconds / 1000 + nanoseconds := (milliseconds % 1000) * 1000000 + *t = windowsFileInfoTime(time.Unix(seconds, nanoseconds)) + return nil +} + +type winFileInfo struct { + Path string `json:"Name"` + Length int64 `json:"Length"` + IsReadOnly bool `json:"IsReadOnly"` + FullName string `json:"FullName"` + Extension string `json:"Extension"` + LastWriteTime windowsFileInfoTime `json:"LastWriteTime"` + Attributes int `json:"Attributes"` + FMode string `json:"Mode"` + Err string `json:"Err"` + fsys *WinFsys +} + +// Name returns the base name of the file. +func (fi *winFileInfo) Name() string { + parts := strings.Split(fi.Path, "\\") + return parts[len(parts)-1] +} + +// Size returns the length in bytes for regular files; system-dependent for others. +func (fi *winFileInfo) Size() int64 { + return fi.Length +} + +// Mode returns the file mode bits. +func (fi *winFileInfo) Mode() fs.FileMode { + if fi.IsReadOnly { + return 0o555 + } + return 0o777 +} + +// ModTime returns the modification time. +func (fi *winFileInfo) ModTime() time.Time { + return time.Time(fi.LastWriteTime) +} + +// IsDir is abbreviation for Mode().IsDir(). +func (fi *winFileInfo) IsDir() bool { + return strings.Contains(fi.FMode, "d") +} + +// Sys returns the underlying data source (can return nil). +func (fi *winFileInfo) Sys() any { + return fi.fsys +} + +// Info returns self, satisfying fs.DirEntry interface. +func (fi *winFileInfo) Info() (fs.FileInfo, error) { + return fi, nil +} + +// Type returns the type bits for the entry. +// The type bits are a subset of the usual FileMode bits, those returned by the FileMode.Type method. +func (fi *winFileInfo) Type() fs.FileMode { + return fi.Mode().Type() +} diff --git a/pkg/rigfs/winfsys.go b/pkg/rigfs/winfsys.go index e2f31f9a..4f279a29 100644 --- a/pkg/rigfs/winfsys.go +++ b/pkg/rigfs/winfsys.go @@ -1,518 +1,174 @@ package rigfs import ( - "bufio" - _ "embed" - "encoding/hex" "encoding/json" "errors" "fmt" - "io" "io/fs" "os" "strings" - "sync" - "time" "github.com/k0sproject/rig/exec" - "github.com/k0sproject/rig/log" ps "github.com/k0sproject/rig/pkg/powershell" ) -const bufSize = 32768 +var _ fs.FS = (*WinFsys)(nil) -var ( - // ErrNotRunning is returned when the rigrcp process is not running - ErrNotRunning = errors.New("rigrcp is not running") - // ErrRcpCommandFailed is returned when a command to the rigrcp process fails - ErrRcpCommandFailed = errors.New("rigrcp command failed") -) - -// rigWinRCPScript is a helper script for transferring files between local and remote systems -// -//go:embed rigrcp.ps1 -var rigWinRCPScript string - -var ( - _ fs.File = (*winFile)(nil) - _ fs.ReadDirFile = (*winDir)(nil) - _ fs.FS = (*WinFsys)(nil) -) - -// WinFsys is a fs.FS implementation for remote Windows hosts +// WinFsys is a fs.FS implemen{ type WinFsys struct { conn connection - rcp *winRCP - buf []byte - mu sync.Mutex -} - -type seekResponse struct { - Position int64 `json:"position"` -} - -type readResponse struct { - Bytes int64 `json:"bytes"` -} - -type sumResponse struct { - Sha256 string `json:"sha256"` -} - -type openResponse struct { - ID string `json:"id"` - Pos int64 `json:"pos"` - EOF bool `json:"eof"` - Name string `json:"name"` - IsDir bool `json:"isdir"` -} - -type rigrcpResponse struct { - Err error `json:"-"` - ErrString string `json:"error"` - Stat *FileInfo `json:"stat"` - Dir []*FileInfo `json:"dir"` - Seek *seekResponse `json:"seek"` - Read *readResponse `json:"read"` - Sum *sumResponse `json:"sum"` - Open *openResponse `json:"open"` -} - -func (r *rigrcpResponse) UnmarshalJSON(b []byte) error { - type rigresponse *rigrcpResponse - rr := rigresponse(r) - if err := json.Unmarshal(b, rr); err != nil { - return fmt.Errorf("%w: failed to unmarshal rigrcp response: %w", ErrCommandFailed, err) - } - if r.ErrString != "" { - r.Err = fmt.Errorf("%w: %s", ErrCommandFailed, strings.TrimSpace(r.ErrString)) - } - return nil + opts []exec.Option } // NewWindowsFsys returns a new fs.FS implementing filesystem for Windows targets func NewWindowsFsys(conn connection, opts ...exec.Option) *WinFsys { return &WinFsys{ conn: conn, - buf: make([]byte, bufSize), - rcp: &winRCP{conn: conn, opts: opts}, - } -} - -type winRCP struct { - conn connection - opts []exec.Option - mu sync.Mutex - done chan struct{} - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.WriteCloser - running bool -} - -func (rcp *winRCP) run() error { - log.Debugf("starting rigrcp") - rcp.mu.Lock() - defer rcp.mu.Unlock() - - stdinR, stdinW := io.Pipe() - stdoutR, stdoutW := io.Pipe() - rcp.stdout = bufio.NewReader(stdoutR) - rcp.stdin = stdinW - rcp.stderr = os.Stderr - rcp.done = make(chan struct{}) - cmd := ps.CompressedCmd(rigWinRCPScript) - log.Tracef("rigrcp command size: %d", len(cmd)) - waiter, err := rcp.conn.ExecStreams(cmd, stdinR, stdoutW, rcp.stderr, rcp.opts...) - if err != nil { - return fmt.Errorf("%w: failed to start rigrcp: %w", ErrCommandFailed, err) - } - go func() { - rcp.running = true - - err := waiter.Wait() - log.Debugf("rigrcp exited") - rcp.running = false - if err != nil { - log.Errorf("rigrcp: %v", err) - } - close(rcp.done) - _ = rcp.stdin.Close() - _ = rcp.stderr.Close() - _ = stdoutR.Close() - }() - - time.Sleep(time.Second) - if !rcp.running { - return fmt.Errorf("%w: rigrcp failed to start", ErrCommandFailed) - } - log.Tracef("started rigrcp") - return nil -} - -func (rcp *winRCP) command(cmd string) (rigrcpResponse, error) { //nolint:cyclop - var res rigrcpResponse - if !rcp.running { - if err := rcp.run(); err != nil { - return res, err - } - } - rcp.mu.Lock() - defer rcp.mu.Unlock() - - resp := make(chan []byte, 1) - go func() { - b, err := rcp.stdout.ReadBytes(0) - if err != nil { - log.Errorf("failed to read response: %v", err) - close(resp) - return - } - log.Tracef("rigrcp raw response:\n%s", hex.Dump(b)) - resp <- b[:len(b)-1] // drop the zero byte - }() - - log.Tracef("rigrcp raw request:\n%s", hex.Dump([]byte(cmd+"\n"))) - if _, err := rcp.stdin.Write([]byte(cmd + "\n")); err != nil { - return res, fmt.Errorf("%w: %w", ErrRcpCommandFailed, err) + opts: opts, } - select { - case <-rcp.done: - return res, fmt.Errorf("%w: rigrcp exited", ErrRcpCommandFailed) - case data := <-resp: - if data == nil { - return res, nil - } - if len(data) == 0 { - return res, nil - } - if err := json.Unmarshal(data, &res); err != nil { - return res, fmt.Errorf("%w: failed to unmarshal response: %w", ErrRcpCommandFailed, err) - } - if res.Err != nil { - if res.Err.Error() == "command failed: eof" { - return res, io.EOF - } - if strings.Contains(res.Err.Error(), "\"file not found\"") { - return res, fs.ErrNotExist - } - } - return res, nil - } -} - -// winFile is a file on a Windows target. It implements fs.File. -type winFile struct { - fsys *WinFsys - id string - path string -} - -// Seek sets the offset for the next Read or Write on the remote file. -// The whence argument controls the interpretation of offset. -// 0 = offset from the beginning of file -// 1 = offset from the current position -// 2 = offset from the end of file -func (f *winFile) Seek(offset int64, whence int) (int64, error) { - resp, err := f.fsys.rcp.command(fmt.Sprintf("seek %s %d %d", f.id, offset, whence)) - if err != nil { - return -1, &fs.PathError{Op: "seek", Path: f.path, Err: fmt.Errorf("%w: seek: %w", ErrRcpCommandFailed, err)} - } - if resp.Seek == nil { - return -1, &fs.PathError{Op: "seek", Path: f.path, Err: fmt.Errorf("%w: seek response: %v", ErrRcpCommandFailed, resp)} - } - return resp.Seek.Position, nil } -// winDir is a directory on a Windows target. It implements fs.ReadDirFile. -type winDir struct { - winFile - entries []fs.DirEntry - hw int -} - -// ReadDir reads the contents of the directory and returns -// a slice of up to n fs.DirEntry values in directory order. -// Subsequent calls on the same file will yield further DirEntry values. -func (d *winDir) ReadDir(n int) ([]fs.DirEntry, error) { - if n == 0 { - return d.winFile.fsys.ReadDir(d.path) - } - if d.entries == nil { - entries, err := d.winFile.fsys.ReadDir(d.path) - if err != nil { - return nil, err +var statCmdTemplate = `if (Test-Path -LiteralPath %[1]s) { + $item = Get-Item -LiteralPath %[1]s | Select-Object Name, FullName, LastWriteTime, Attributes, Mode, Length | ForEach-Object { + $isReadOnly = [bool]($_.Attributes -band [System.IO.FileAttributes]::ReadOnly) + $_ | Add-Member -NotePropertyName IsReadOnly -NotePropertyValue $isReadOnly -PassThru } - d.entries = entries - d.hw = 0 - } - if d.hw >= len(d.entries) { - return nil, io.EOF - } - var min int - if n > len(d.entries)-d.hw { - min = len(d.entries) - d.hw - } else { - min = n - } - old := d.hw - d.hw += min - return d.entries[old:d.hw], nil -} - -// CopyFromN copies n bytes from the reader to the opened file on the target. -// The alt io.Writer parameter can be set to a non nil value if a progress bar or such -// is desired. -func (f *winFile) CopyFromN(src io.Reader, num int64, alt io.Writer) (int64, error) { - _, err := f.fsys.rcp.command(fmt.Sprintf("w %s %d", f.id, num)) - if err != nil { - return 0, &fs.PathError{Op: "copy-to", Path: f.path, Err: fmt.Errorf("%w: copy: %w", ErrRcpCommandFailed, err)} - } - var writer io.Writer - if alt != nil { - writer = io.MultiWriter(f.fsys.rcp.stdin, alt) + $item | ConvertTo-Json -Compress } else { - writer = f.fsys.rcp.stdin - } - copied, err := io.CopyN(writer, src, num) - if err != nil { - return copied, &fs.PathError{Op: "copy-to", Path: f.path, Err: fmt.Errorf("%w: copy stream: %w", ErrRcpCommandFailed, err)} - } - return copied, nil -} - -// Copy copies the complete remote file from the current file position to the supplied io.Writer. -func (f *winFile) Copy(dst io.Writer) (int64, error) { - resp, err := f.fsys.rcp.command(fmt.Sprintf("r %s -1", f.id)) - if errors.Is(err, io.EOF) { - return 0, io.EOF - } - if err != nil { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: copy: %w", ErrRcpCommandFailed, err)} - } - if resp.Read == nil { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: copy response: %v", ErrCommandFailed, resp)} - } - if resp.Read.Bytes == 0 { - return 0, io.EOF - } - var totalRead int64 - for totalRead < resp.Read.Bytes { - f.fsys.mu.Lock() - read, err := f.fsys.rcp.stdout.Read(f.fsys.buf) - totalRead += int64(read) - if err != nil { - f.fsys.mu.Unlock() - return totalRead, &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: copy (read): %w", ErrRcpCommandFailed, err)} - } - _, err = dst.Write(f.fsys.buf[:read]) - f.fsys.mu.Unlock() - if err != nil { - return totalRead, &fs.PathError{Op: "write", Path: f.path, Err: fmt.Errorf("%w: copy (write): %w", ErrRcpCommandFailed, err)} - } - } - return totalRead, nil -} + Write-Output '{"Err":"does not exist"}' + }` -// Write writes len(p) bytes from p to the remote file. -func (f *winFile) Write(p []byte) (int, error) { - _, err := f.fsys.rcp.command(fmt.Sprintf("w %s %d", f.id, len(p))) - if errors.Is(err, io.EOF) { - return 0, io.EOF - } - if err != nil { - return 0, &fs.PathError{Op: "write", Path: f.path, Err: fmt.Errorf("%w: initiate write: %w", ErrRcpCommandFailed, err)} - } - written, err := f.fsys.rcp.stdin.Write(p) +// Stat returns fs.FileInfo for the remote file. +func (fsys *WinFsys) Stat(name string) (fs.FileInfo, error) { + out, err := fsys.conn.ExecOutput(ps.Cmd(fmt.Sprintf(statCmdTemplate, ps.DoubleQuotePath(name))), fsys.opts...) if err != nil { - return written, &fs.PathError{Op: "write", Path: f.path, Err: fmt.Errorf("%w: write error: %w", ErrRcpCommandFailed, err)} + return nil, &fs.PathError{Op: OpStat, Path: name, Err: fmt.Errorf("%w: %w", err, fs.ErrNotExist)} } - return written, nil -} -// Read reads up to len(p) bytes from the remote file. -func (f *winFile) Read(p []byte) (int, error) { - resp, err := f.fsys.rcp.command(fmt.Sprintf("r %s %d", f.id, len(p))) - if errors.Is(err, io.EOF) { - return 0, io.EOF - } - if err != nil { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: read: %w", ErrRcpCommandFailed, err)} + fi := &winFileInfo{fsys: fsys} + if err := json.Unmarshal([]byte(out), fi); err != nil { + return nil, &fs.PathError{Op: OpStat, Path: name, Err: fmt.Errorf("%w: stat (parse)", err)} } - if resp.Read == nil { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: read response: %v", ErrRcpCommandFailed, resp)} - } - if resp.Read.Bytes == 0 { - return 0, io.EOF - } - var totalRead int64 - for totalRead < resp.Read.Bytes { - read, err := f.fsys.rcp.stdout.Read(p[totalRead:resp.Read.Bytes]) - log.Tracef("read %d bytes from %s", read, f.path) - totalRead += int64(read) - if err != nil { - return int(totalRead), &fs.PathError{Op: "read", Path: f.path, Err: fmt.Errorf("%w: read: %w", ErrRcpCommandFailed, err)} + if fi.Err != "" { + if strings.Contains(fi.Err, "does not exist") { + return nil, &fs.PathError{Op: OpStat, Path: name, Err: fs.ErrNotExist} } + return nil, &fs.PathError{Op: OpStat, Path: name, Err: fmt.Errorf("stat: %v", fi.Err)} //nolint:goerr113 } - log.Tracef("read %d bytes total", totalRead) - return int(totalRead), nil -} - -// Stat returns the FileInfo for the remote file. -func (f *winFile) Stat() (fs.FileInfo, error) { - return f.fsys.Stat(f.path) -} - -// Close closes the remote file. -func (f *winFile) Close() error { - _, err := f.fsys.rcp.command(fmt.Sprintf("c %s", f.id)) - if err != nil { - return &fs.PathError{Op: "close", Path: f.path, Err: fmt.Errorf("%w: close: %w", ErrRcpCommandFailed, err)} - } - return nil -} - -// Open opens the named file for reading and returns fs.File. -// Use OpenFile to get a file that can be written to or if you need any of the methods not -// available on fs.File interface without type assertion. -func (fsys *WinFsys) Open(name string) (fs.File, error) { - f, err := fsys.OpenFile(name, os.O_RDONLY, 0) - if err != nil { - return nil, err - } - return f, nil -} - -func toWindowsPath(name string) string { - return strings.Join(strings.Split(name, "/"), "\\") -} - -// OpenFile opens the named remote file with the specified flags. os.O_EXCL and permission bits are ignored on Windows. -// For a description of the flags, see https://pkg.go.dev/os#pkg-constants -func (fsys *WinFsys) OpenFile(name string, flags int, _ fs.FileMode) (File, error) { - var modeStr string - switch { - case flags&os.O_WRONLY == os.O_WRONLY: - modeStr = "w" - case flags&os.O_RDWR == os.O_RDWR: - modeStr = "rw" - case flags&os.O_APPEND == os.O_APPEND: - modeStr = "a" - case flags&os.O_CREATE == os.O_CREATE: - modeStr = "c" - case flags&(os.O_WRONLY|os.O_RDWR) == 0: - modeStr = "ro" - default: - return nil, &fs.PathError{Op: "open", Path: name, Err: fmt.Errorf("%w: invalid mode: %d", ErrRcpCommandFailed, flags)} - } - - name = toWindowsPath(name) - - log.Debugf("opening remote file %s (mode %s)", name, modeStr) - resp, err := fsys.rcp.command(fmt.Sprintf("o %s %s", modeStr, name)) - log.Debugf("rigrcp response: %+v : %v", resp, err) - if err != nil { - return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} - } - if resp.Open == nil { - return nil, &fs.PathError{Op: "open", Path: name, Err: fmt.Errorf("%w: open response: %v", ErrRcpCommandFailed, resp)} - } - file := &winFile{fsys: fsys, path: resp.Open.Name, id: resp.Open.ID} - if resp.Open.IsDir { - return &winDir{winFile: *file}, nil - } - return file, nil -} - -// Stat returns fs.FileInfo for the remote file. -func (fsys *WinFsys) Stat(name string) (fs.FileInfo, error) { - name = toWindowsPath(name) - resp, err := fsys.rcp.command(fmt.Sprintf("stat %s", name)) - if err != nil { - return nil, &fs.PathError{Op: "stat", Path: name, Err: fmt.Errorf("%w: stat %s: %w", ErrRcpCommandFailed, name, err)} - } - if resp.Stat == nil { - return nil, &fs.PathError{Op: "stat", Path: name, Err: fmt.Errorf("%w: stat response: %v", ErrRcpCommandFailed, resp)} - } - return resp.Stat, nil + return fi, nil } // Sha256 returns the SHA256 hash of the remote file. func (fsys *WinFsys) Sha256(name string) (string, error) { - name = toWindowsPath(name) - sum, err := fsys.conn.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-FileHash %s -Algorithm SHA256).Hash.ToLower()", ps.DoubleQuote(name)))) + sum, err := fsys.conn.ExecOutput(ps.Cmd(fmt.Sprintf("(Get-FileHash %s -Algorithm SHA256).Hash.ToLower()", ps.DoubleQuotePath(name)))) if err != nil { - return "", &fs.PathError{Op: "sum", Path: name, Err: fmt.Errorf("%w: sha256sum: %w", ErrRcpCommandFailed, err)} + return "", &fs.PathError{Op: "sum", Path: name, Err: fmt.Errorf("sha256sum: %w", err)} } return sum, nil } // ReadDir reads the directory named by dirname and returns a list of directory entries. func (fsys *WinFsys) ReadDir(name string) ([]fs.DirEntry, error) { - name = toWindowsPath(name) - resp, err := fsys.rcp.command(fmt.Sprintf("dir %s", name)) + f, err := fsys.OpenFile(name, os.O_RDONLY, 0) if err != nil { - return nil, &fs.PathError{Op: "readdir", Path: name, Err: fmt.Errorf("%w: readdir: %w: %w", ErrRcpCommandFailed, err, fs.ErrNotExist)} - } - if resp.Dir == nil { - return nil, nil + return nil, &fs.PathError{Op: "readdir", Path: name, Err: err} } - entries := make([]fs.DirEntry, len(resp.Dir)) - for i, entry := range resp.Dir { - entries[i] = entry + defer f.Close() + dir, ok := f.(*winDir) + if !ok { + return nil, &fs.PathError{Op: "readdir", Path: name, Err: fmt.Errorf("readdir: %w", fs.ErrInvalid)} } - return entries, nil + + return dir.ReadDir(-1) } // Remove deletes the named file or (empty) directory. func (fsys *WinFsys) Remove(name string) error { - name = toWindowsPath(name) - if existing, err := fsys.Stat(name); err == nil && existing.IsDir() { return fsys.removeDir(name) } - if err := fsys.conn.Exec(fmt.Sprintf("del %s", ps.DoubleQuote(name))); err != nil { - return fmt.Errorf("%w: remove %s: %w", ErrCommandFailed, name, err) + if err := fsys.conn.Exec(fmt.Sprintf("del %s", ps.DoubleQuotePath(name))); err != nil { + return fmt.Errorf("remove %s: %w", name, err) } + return nil } // RemoveAll deletes the named file or directory and all its child items func (fsys *WinFsys) RemoveAll(name string) error { - name = toWindowsPath(name) - if existing, err := fsys.Stat(name); err == nil && existing.IsDir() { return fsys.removeDirAll(name) } - if err := fsys.conn.Exec(fmt.Sprintf("del %s", ps.DoubleQuote(name))); err != nil { - return fmt.Errorf("%w: remove all %s: %w", ErrCommandFailed, name, err) + if err := fsys.conn.Exec(fmt.Sprintf("del %s", ps.DoubleQuotePath(name))); err != nil { + return fmt.Errorf("remove %s: %w", name, err) } + return nil } func (fsys *WinFsys) removeDir(name string) error { - if err := fsys.conn.Exec(fmt.Sprintf("rmdir /q %s", ps.DoubleQuote(name))); err != nil { - return fmt.Errorf("%w: rmdir %s: %w", ErrCommandFailed, name, err) + if err := fsys.conn.Exec(fmt.Sprintf("rmdir /q %s", ps.DoubleQuotePath(name))); err != nil { + return fmt.Errorf("rmdir %s: %w", name, err) } return nil } func (fsys *WinFsys) removeDirAll(name string) error { - if err := fsys.conn.Exec(fmt.Sprintf("rmdir /s /q %s", ps.DoubleQuote(name))); err != nil { - return fmt.Errorf("%w: rmdir %s: %w", ErrCommandFailed, name, err) + if err := fsys.conn.Exec(fmt.Sprintf("rmdir /s /q %s", ps.DoubleQuotePath(name))); err != nil { + return fmt.Errorf("rmdir %s: %w", name, err) } + return nil } // MkDirAll creates a directory named path, along with any necessary parents. The permission bits are ignored on Windows. func (fsys *WinFsys) MkDirAll(name string, _ fs.FileMode) error { - name = toWindowsPath(name) - - if err := fsys.conn.Exec(ps.Cmd(fmt.Sprintf("New-Item -ItemType Directory -Force -Path %s", ps.DoubleQuote(name)))); err != nil { - return fmt.Errorf("%w: mkdir %s: %w", ErrCommandFailed, name, err) + if err := fsys.conn.Exec(ps.Cmd(fmt.Sprintf("New-Item -ItemType Directory -Force -Path %s", ps.DoubleQuotePath(name)))); err != nil { + return fmt.Errorf("mkdir %s: %w", name, err) } return nil } + +type opener interface { + open(flags int) error +} + +// Open opens the named file for reading and returns fs.File. +// Use OpenFile to get a file that can be written to or if you need any of the methods not +// available on fs.File interface without type assertion. +func (fsys *WinFsys) Open(name string) (fs.File, error) { + f, err := fsys.OpenFile(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + + return f, nil +} + +// OpenFile opens the named remote file with the specified flags. os.O_EXCL and permission bits are ignored on Windows. +// For a description of the flags, see https://pkg.go.dev/os#pkg-constants +func (fsys *WinFsys) OpenFile(name string, flags int, _ fs.FileMode) (File, error) { + name = ps.ToWindowsPath(name) + fi, err := fsys.Stat(name) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fmt.Errorf("stat: %w", err)} + } + var o opener + if fi != nil && fi.IsDir() { + o = &winDir{winFileDirBase: winFileDirBase{withPath: withPath{name}, fsys: fsys}} + } else { + o = &winFile{winFileDirBase: winFileDirBase{withPath: withPath{name}, fsys: fsys}} + } + if err := o.open(flags); err != nil { + return nil, fmt.Errorf("open: %w", err) + } + f, ok := o.(File) + if !ok { + return nil, &fs.PathError{Op: OpOpen, Path: name, Err: fmt.Errorf("%w: open: %w", ErrCommandFailed, fs.ErrInvalid)} + } + + return f, nil +} diff --git a/pkg/rigfs/withname.go b/pkg/rigfs/withname.go new file mode 100644 index 00000000..e25d7512 --- /dev/null +++ b/pkg/rigfs/withname.go @@ -0,0 +1,24 @@ +package rigfs + +import "io/fs" + +const ( + OpClose = "close" // OpClose Close operation + OpOpen = "open" // OpOpen Open operation + OpRead = "read" // OpRead Read operation + OpSeek = "seek" // OpSeek Seek operation + OpStat = "stat" // OpStat Stat operation + OpWrite = "write" // OpWrite Write operation +) + +type withPath struct { + path string +} + +func (w *withPath) Name() string { + return w.path +} + +func (w *withPath) pathErr(op string, err error) error { + return &fs.PathError{Op: op, Path: w.path, Err: err} +} diff --git a/ssh.go b/ssh.go index 4dfd2d98..2fbafd65 100644 --- a/ssh.go +++ b/ssh.go @@ -8,7 +8,6 @@ import ( "io" "net" "os" - "slices" "sort" "strconv" "strings" @@ -16,12 +15,12 @@ import ( "github.com/acarl005/stripansi" "github.com/creasty/defaults" - "github.com/google/shlex" "github.com/k0sproject/rig/exec" "github.com/k0sproject/rig/log" "github.com/k0sproject/rig/pkg/ssh/agent" "github.com/k0sproject/rig/pkg/ssh/hostkey" "github.com/kevinburke/ssh_config" + "github.com/mattn/go-shellwords" ssh "golang.org/x/crypto/ssh" "golang.org/x/term" ) @@ -43,7 +42,8 @@ type SSH struct { // authMethods, err := rig.ParseSSHPrivateKey(key, rig.DefaultPassphraseCallback) AuthMethods []ssh.AuthMethod `yaml:"-"` - name string + alias string + name string isWindows bool knowOs bool @@ -122,6 +122,27 @@ func expandAndValidatePath(path string) (string, error) { return path, nil } +// compact replaces consecutive runs of equal elements with a single copy. +// This is like the uniq command found on Unix. +// +// Taken from stdlib's slices package, to work around a problem on github actions +// (package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.12/x64/src/slices) +func compact[S ~[]E, E comparable](slice S) S { + if len(slice) < 2 { + return slice + } + i := 1 + for k := 1; k < len(slice); k++ { + if slice[k] != slice[k-1] { + if i != k { + slice[i] = slice[k] + } + i++ + } + } + return slice[:i] +} + func (c *SSH) keypathsFromConfig() []string { log.Tracef("%s: trying to get a keyfile path from ssh config", c) idf := c.getConfigAll("IdentityFile") @@ -130,7 +151,7 @@ func (c *SSH) keypathsFromConfig() []string { // To work around this, the hard coded list of known defaults are appended to the list idf = append(idf, defaultKeypaths...) sort.Strings(idf) - idf = slices.Compact(idf) + idf = compact(idf) if len(idf) > 0 { log.Tracef("%s: detected %d identity file paths from ssh config: %v", c, len(idf), idf) @@ -148,7 +169,7 @@ func (c *SSH) initGlobalDefaults() { // To work around this, the hard coded list of known defaults are appended to the list dummyHostIdentityFiles = append(dummyHostIdentityFiles, defaultKeypaths...) sort.Strings(dummyHostIdentityFiles) - dummyHostIdentityFiles = slices.Compact(dummyHostIdentityFiles) + dummyHostIdentityFiles = compact(dummyHostIdentityFiles) for _, keyPath := range dummyHostIdentityFiles { if expanded, err := expandAndValidatePath(keyPath); err == nil { dummyhostKeyPaths = append(dummyhostKeyPaths, expanded) @@ -189,6 +210,21 @@ func (c *SSH) SetDefaults() { paths := c.keypathsFromConfig() + if c.Port == 0 || c.Port == 22 { + ports := c.getConfigAll("Port") + if len(ports) > 0 { + if p, err := strconv.Atoi(ports[0]); err == nil { + c.Port = p + } + } + } + + addrs := c.getConfigAll("HostName") + if len(addrs) > 0 { + c.alias = c.Address + c.Address = addrs[0] + } + for _, p := range paths { expanded, err := expandAndValidatePath(p) if err != nil { @@ -221,11 +257,9 @@ func (c *SSH) IPAddress() string { // you can override it with your own implementation for testing purposes var SSHConfigGetAll = ssh_config.GetAll -// try with port, if no results, try without func (c *SSH) getConfigAll(key string) []string { - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - if val := SSHConfigGetAll(dst, key); len(val) > 0 { - return val + if c.alias != "" { + return SSHConfigGetAll(c.alias, key) } return SSHConfigGetAll(c.Address, key) } @@ -322,8 +356,9 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { kfs := c.getConfigAll("UserKnownHostsFile") // splitting the result as for some reason ssh_config sometimes seems to // return a single string containing space separated paths - if files, err := shlex.Split(strings.Join(kfs, " ")); err == nil { + if files, err := shellwords.Parse(strings.Join(kfs, " ")); err == nil { for _, f := range files { + log.Tracef("%s: trying known_hosts file from ssh config %s", c, f) exp, err := expandPath(f) if err == nil { khPath = exp @@ -523,7 +558,7 @@ const ( // ExecStreams executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. -func (c *SSH) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) { +func (c *SSH) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { if c.client == nil { return nil, ErrNotConnected } @@ -534,6 +569,8 @@ func (c *SSH) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Wri return nil, fmt.Errorf("%w: build command: %w", ErrCommandFailed, err) } + execOpts.LogCmd(c.String(), cmd) + session, err := c.client.NewSession() if err != nil { return nil, fmt.Errorf("%w: create new session: %w", ErrCommandFailed, err) @@ -551,7 +588,7 @@ func (c *SSH) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Wri } // Exec executes a command on the host -func (c *SSH) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop +func (c *SSH) Exec(cmd string, opts ...exec.Option) error { //nolint:gocognit,cyclop execOpts := exec.Build(opts...) session, err := c.client.NewSession() if err != nil { @@ -617,7 +654,7 @@ func (c *SSH) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop } }() - gotErrors := false + var errors []string wg.Add(1) go func() { @@ -625,12 +662,14 @@ func (c *SSH) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop outputScanner := bufio.NewScanner(stderr) for outputScanner.Scan() { - gotErrors = true - execOpts.AddOutput(c.String(), "", outputScanner.Text()+"\n") + msg := outputScanner.Text() + if msg != "" { + errors = append(errors, msg) + execOpts.LogErrorf("%s: %s", c, msg) + } } if err := outputScanner.Err(); err != nil { - gotErrors = true execOpts.LogErrorf("%s: %s", c, err.Error()) } }() @@ -642,8 +681,8 @@ func (c *SSH) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop return fmt.Errorf("ssh session wait: %w", err) } - if c.knowOs && c.isWindows && (!execOpts.AllowWinStderr && gotErrors) { - return fmt.Errorf("%w: data in stderr", ErrCommandFailed) + if c.knowOs && c.isWindows && (!execOpts.AllowWinStderr && len(errors) > 0) { + return fmt.Errorf("%w: received data in stderr: %s", ErrCommandFailed, strings.Join(errors, "\n")) } return nil diff --git a/test/Makefile b/test/Makefile index e49c02ed..82614c82 100644 --- a/test/Makefile +++ b/test/Makefile @@ -6,7 +6,7 @@ REPLICAS ?= 1 LINUX_IMAGE ?= "quay.io/k0sproject/bootloose-ubuntu20.04" .PHONY: test -test: rigtest +test: gomod ./test.sh bootloose := $(shell which bootloose) @@ -24,16 +24,16 @@ ifeq ($(sshkeygen),) $(error 'ssh-keygen' NOT found in path, please install it and re-run) endif -.PHONY: rigtest -rigtest: - go build -o rigtest ../cmd/rigtest - $(bootloose): go install github.com/k0sproject/bootloose@latest .ssh: mkdir -p .ssh +.PHONY: gomod +gomod: + go mod download + .ssh/identity: .ssh rm -f .ssh/identity ssh-keygen -t $(KEY_TYPE) -b $(KEY_SIZE) -f .ssh/identity -N $(KEY_PASSPHRASE) @@ -62,7 +62,7 @@ delete-host: bootloose.yaml .PHONY: clean clean: delete-host - rm -f bootloose.yaml identity rigtest + rm -f bootloose.yaml identity rm -rf .ssh docker network rm bootloose-cluster || true @@ -71,9 +71,9 @@ sshport: @$(bootloose) show node0 -o json|grep hostPort|grep -oE "[0-9]+" .PHONY: run -run: rigtest create-host - ./rigtest \ - -host 127.0.0.1:$(shell $(MAKE) sshport) \ - -keypath $(KEY_PATH) \ +run: create-host gomod + go test -v ./... -args \ + -host 127.0.0.1 \ + -port $(shell $(MAKE) sshport) \ + -ssh-keypath $(KEY_PATH) \ -user root - diff --git a/test/rig_test.go b/test/rig_test.go new file mode 100644 index 00000000..6b7284b1 --- /dev/null +++ b/test/rig_test.go @@ -0,0 +1,707 @@ +package test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "flag" + "fmt" + "io" + "io/fs" + "os" + "path" + "strings" + "testing" + "time" + + "github.com/k0sproject/rig" + "github.com/k0sproject/rig/exec" + rigos "github.com/k0sproject/rig/os" + "github.com/k0sproject/rig/os/registry" + _ "github.com/k0sproject/rig/os/support" + "github.com/k0sproject/rig/pkg/rigfs" + "github.com/kevinburke/ssh_config" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +// Define variables directly +var ( + targetHost string + targetPort int + username string + protocol string + keyPath string + configPath string + password string + useHTTPS bool + onlyConnect bool + privateKey string + enableMultiplex bool +) + +func pathBase(p string) string { + return path.Base(strings.ReplaceAll(p, "\\", "/")) +} + +func pathDir(p string) string { + return path.Dir(strings.ReplaceAll(p, "\\", "/")) +} + +func TestMain(m *testing.M) { + flag.StringVar(&targetHost, "host", "", "target host") + flag.IntVar(&targetPort, "port", 22, "target host port (defaulted based on protocol)") + flag.StringVar(&username, "user", "root", "user name") + flag.StringVar(&protocol, "protocol", "ssh", "ssh/winrm/localhost/openssh") + flag.StringVar(&keyPath, "ssh-keypath", "", "ssh keypath") + flag.StringVar(&configPath, "ssh-configpath", "", "ssh config path") + flag.StringVar(&privateKey, "ssh-private-key", "", "ssh private key") + flag.StringVar(&password, "winrm-password", "", "winrm password") + flag.BoolVar(&useHTTPS, "winrm-https", false, "use https for winrm") + flag.BoolVar(&enableMultiplex, "openssh-multiplex", true, "use ssh multiplexing") + flag.BoolVar(&onlyConnect, "connect", false, "only connect to host, dont run other tests") + + // Parse the flags + flag.Parse() + + if targetHost == "" { + // no host, nothing to test + return + } + + if targetPort == 22 && protocol == "winrm" { + if useHTTPS { + targetPort = 5986 + } else { + targetPort = 5985 + } + } + + if configPath != "" { + f, err := os.Open(configPath) + if err != nil { + panic(err) + } + cfg, err := ssh_config.Decode(f) + if err != nil { + panic(err) + } + rig.SSHConfigGetAll = func(dst, key string) []string { + res, err := cfg.GetAll(dst, key) + if err != nil { + return nil + } + return res + } + } + + // Run tests + os.Exit(m.Run()) +} + +func TestConnect(t *testing.T) { + if !onlyConnect { + t.Skip("skip") + return + } + + h := GetHost() + err := retry(func() error { return h.Connect() }) + require.NoError(t, err) + h.Disconnect() +} + +func TestConfigurerSuite(t *testing.T) { + if onlyConnect { + t.Skip("only connect") + return + } + suite.Run(t, &ConfigurerSuite{ConnectedSuite: ConnectedSuite{Host: GetHost()}}) +} + +func TestFsysSuite(t *testing.T) { + if onlyConnect { + t.Skip("only connect") + return + } + + h := GetHost() + t.Run("No sudo", func(t *testing.T) { + suite.Run(t, &FsysSuite{ConnectedSuite: ConnectedSuite{Host: h}}) + }) + + t.Run("Sudo", func(t *testing.T) { + suite.Run(t, &FsysSuite{ConnectedSuite: ConnectedSuite{Host: h}, sudo: true}) + }) +} + +type configurer interface { + WriteFile(rigos.Host, string, string, string) error + LineIntoFile(rigos.Host, string, string, string) error + ReadFile(rigos.Host, string) (string, error) + FileExist(rigos.Host, string) bool + DeleteFile(rigos.Host, string) error + Stat(rigos.Host, string, ...exec.Option) (*rigos.FileInfo, error) + Touch(rigos.Host, string, time.Time, ...exec.Option) error + MkDir(rigos.Host, string, ...exec.Option) error + Sha256sum(rigos.Host, string, ...exec.Option) (string, error) +} + +// Host is a host that utilizes rig for connections +type Host struct { + rig.Connection + + Configurer configurer +} + +// LoadOS is a function that assigns a OS support package to the host and +// typecasts it to a suitable interface +func (h *Host) LoadOS() error { + bf, err := registry.GetOSModuleBuilder(*h.OSVersion) + if err != nil { + return err + } + + h.Configurer = bf().(configurer) + + return nil +} + +func retry(fn func() error) error { + var err error + for i := 0; i < 3; i++ { + err = fn() + if err == nil { + return nil + } + time.Sleep(2 * time.Second) + } + return err +} + +func GetHost() *Host { + h := &Host{} + switch protocol { + case "ssh": + h.SSH = &rig.SSH{ + Address: targetHost, + Port: targetPort, + User: username, + } + + if privateKey != "" { + authM, err := rig.ParseSSHPrivateKey([]byte(privateKey), rig.DefaultPasswordCallback) + if err != nil { + panic(err) + } + h.SSH.AuthMethods = authM + } + + if keyPath != "" { + h.SSH.KeyPath = &keyPath + } + case "winrm": + h.WinRM = &rig.WinRM{ + Address: targetHost, + Port: targetPort, + User: username, + UseHTTPS: useHTTPS, + Insecure: true, + Password: password, + } + case "localhost": + h.Localhost = &rig.Localhost{Enabled: true} + case "openssh": + h.OpenSSH = &rig.OpenSSH{ + Address: targetHost, + DisableMultiplexing: !enableMultiplex, + } + if targetPort != 22 { + h.OpenSSH.Port = &targetPort + } + + if keyPath != "" { + h.OpenSSH.KeyPath = &keyPath + } + if username != "" { + h.OpenSSH.User = &username + } + if configPath != "" { + h.OpenSSH.ConfigPath = &configPath + } + default: + panic("unknown protocol") + } + return h +} + +type SuiteLogger struct { + t *testing.T +} + +func (s *SuiteLogger) Tracef(msg string, args ...interface{}) { + s.t.Log(fmt.Sprintf("%s TRACE %s", time.Now(), fmt.Sprintf(msg, args...))) +} +func (s *SuiteLogger) Debugf(msg string, args ...interface{}) { + s.t.Log(fmt.Sprintf("%s DEBUG %s", time.Now(), fmt.Sprintf(msg, args...))) +} +func (s *SuiteLogger) Infof(msg string, args ...interface{}) { + s.t.Log(fmt.Sprintf("%s INFO %s", time.Now(), fmt.Sprintf(msg, args...))) +} +func (s *SuiteLogger) Warnf(msg string, args ...interface{}) { + s.t.Log(fmt.Sprintf("%s WARN %s", time.Now(), fmt.Sprintf(msg, args...))) +} +func (s *SuiteLogger) Errorf(msg string, args ...interface{}) { + s.t.Log(fmt.Sprintf("%s ERROR %s", time.Now(), fmt.Sprintf(msg, args...))) +} + +type ConnectedSuite struct { + suite.Suite + tempDir string + count int + Host *Host +} + +func (s *ConnectedSuite) SetupSuite() { + rig.SetLogger(&SuiteLogger{s.T()}) + err := retry(func() error { return s.Host.Connect() }) + s.Require().NoError(err) + s.Require().NoError(s.Host.LoadOS()) + s.tempDir = "tmp.rig-test." + time.Now().Format("20060102150405") + s.Require().NoError(s.Host.Fsys().MkDirAll(s.tempDir, 0755)) +} + +func (s *ConnectedSuite) TearDownSuite() { + if s.Host == nil { + return + } + _ = s.Host.Fsys().RemoveAll(s.tempDir) + s.Host.Disconnect() +} + +func (s *ConnectedSuite) TempPath(args ...string) string { + if len(args) == 0 { + s.count++ + return fmt.Sprintf("%s/testfile.%d", s.tempDir, s.count) + } + args[0] = fmt.Sprintf("%s/%s", s.tempDir, args[0]) + return strings.Join(args, "/") +} + +type ConfigurerSuite struct { + ConnectedSuite +} + +func (s *ConfigurerSuite) TestStat() { + s.Run("File does not exist", func() { + stat, err := s.Host.Configurer.Stat(s.Host, s.TempPath("doesnotexist")) + s.Nil(stat) + s.Error(err) + }) + + s.Run("File exists", func() { + f := s.TempPath() + s.Run("Create file", func() { + s.Require().NoError(s.Host.Configurer.Touch(s.Host, f, time.Now())) + }) + + stat, err := s.Host.Configurer.Stat(s.Host, f) + s.Require().NoError(err) + s.True(strings.HasSuffix(f, stat.Name())) // Name() returns Basename + }) +} + +func (s *ConfigurerSuite) TestTouch() { + f := s.TempPath() + now := time.Now() + for _, tt := range []time.Time{now, now.Add(1 * time.Hour)} { + s.Run("Update timestamp "+tt.String(), func() { + s.Require().NoError(s.Host.Configurer.Touch(s.Host, f, now)) + }) + + s.Run("File exists and has correct timestamp "+tt.String(), func() { + stat, err := s.Host.Configurer.Stat(s.Host, f) + s.Require().NoError(err) + s.NotNil(stat) + s.Equal(now.Unix(), stat.ModTime().Unix()) + }) + } +} + +func (s *ConfigurerSuite) TestFileAccess() { + f := s.TempPath() + s.Run("File does not exist", func() { + s.False(s.Host.Configurer.FileExist(s.Host, f)) + }) + + s.Run("Write file", func() { + s.Require().NoError(s.Host.Configurer.WriteFile(s.Host, f, "test\ntest2\ntest3", "0644")) + }) + + s.Run("File exists", func() { + s.True(s.Host.Configurer.FileExist(s.Host, f)) + }) + + s.Run("Read file and verify contents", func() { + content, err := s.Host.Configurer.ReadFile(s.Host, f) + s.Require().NoError(err) + s.Equal("test\ntest2\ntest3", content) + }) + + s.Run("Replace line in file", func() { + s.Require().NoError(s.Host.Configurer.LineIntoFile(s.Host, f, "test2", "test4")) + }) + + s.Run("Re-read file and verify contents", func() { + content, err := s.Host.Configurer.ReadFile(s.Host, f) + s.Require().NoError(err) + // TODO: LineIntoFile adds a trailing newline + s.Equal("test\ntest4\ntest3", strings.TrimSpace(content)) + }) + + s.Run("Delete file", func() { + s.Require().NoError(s.Host.Configurer.DeleteFile(s.Host, f)) + }) + + s.Run("File does not exist", func() { + s.False(s.Host.Configurer.FileExist(s.Host, f)) + }) +} + +func testFile(size int64) (string, error) { + // Create a temporary file. + file, err := os.CreateTemp("", "rigtest.*.dat") + if err != nil { + return "", err + } + defer file.Close() + + // Write random data to the file. + _, err = io.CopyN(file, rand.Reader, size) + if err != nil { + return "", err + } + + return file.Name(), nil +} + +func (s *ConfigurerSuite) TestUpload() { + for _, size := range []int64{500, 100 * 1024, 1024 * 1024} { + s.Run(fmt.Sprintf("File size %d", size), func() { + fn, err := testFile(size) + s.Require().NoError(err) + defer os.Remove(fn) + defer s.Host.Configurer.DeleteFile(s.Host, s.TempPath(pathBase(fn))) + + s.Run("Upload file", func() { + s.Require().NoError(s.Host.Upload(fn, s.TempPath(pathBase(fn)))) + }) + + s.Run("Verify file size", func() { + stat, err := s.Host.Configurer.Stat(s.Host, s.TempPath(pathBase(fn))) + s.Require().NoError(err) + s.Require().NotNil(stat) + s.Equal(size, stat.Size()) + }) + + s.Run("Verify file contents", func() { + sum, err := s.Host.Configurer.Sha256sum(s.Host, s.TempPath(pathBase(fn))) + s.Require().NoError(err) + sha := sha256.New() + f, err := os.Open(fn) + s.Require().NoError(err) + _, err = io.Copy(sha, f) + s.Require().NoError(err) + s.Equal(hex.EncodeToString(sha.Sum(nil)), sum) + }) + + }) + } +} + +type FsysSuite struct { + ConnectedSuite + sudo bool + fsys rigfs.Fsys +} + +func (s *FsysSuite) SetupSuite() { + s.ConnectedSuite.SetupSuite() + if s.sudo { + if s.Host.IsWindows() { + s.T().Skip("sudo not supported on windows") + return + } + s.fsys = s.Host.SudoFsys() + } else { + s.fsys = s.Host.Fsys() + } +} + +func (s *FsysSuite) TestMkdir() { + s.T().Log("testmkdir") + testPath := s.TempPath("test", "subdir") + defer func() { + _ = s.fsys.RemoveAll(testPath) + }() + s.Run("Create directory", func() { + s.T().Log("mkdirall") + s.Require().NoError(s.fsys.MkDirAll(testPath, 0755)) + }) + s.Run("Verify directory exists", func() { + s.T().Log("stat") + stat, err := s.fsys.Stat(testPath) + s.Require().NoError(err) + s.Run("Check permissions", func() { + if s.Host.IsWindows() { + s.T().Skip("Windows does not support chmod permissions") + } + s.Equal(os.FileMode(0755), stat.Mode().Perm()) + parent, err := s.fsys.Stat(s.TempPath("test")) + s.Require().NoError(err) + s.Equal(os.FileMode(0755), parent.Mode().Perm()) + }) + }) +} + +func (s *FsysSuite) TestRemove() { + testPath := s.TempPath("test", "subdir") + s.Run("Create directory", func() { + s.Require().NoError(s.fsys.MkDirAll(testPath, 0755)) + }) + s.Run("Remove directory", func() { + s.Require().NoError(s.fsys.RemoveAll(testPath)) + }) + s.Run("Verify directory does not exist", func() { + stat, err := s.fsys.Stat(testPath) + s.Nil(stat) + s.Error(err) + s.True(os.IsNotExist(err)) + }) + s.Run("Remove parent directory", func() { + s.Require().NoError(s.fsys.RemoveAll(s.TempPath("test"))) + }) + s.Run("Verify parent directory does not exist", func() { + stat, err := s.fsys.Stat(s.TempPath("test")) + s.Nil(stat) + s.Error(err) + s.True(os.IsNotExist(err)) + }) +} + +func (s *FsysSuite) TestReadWriteFile() { + for _, testFileSize := range []int64{ + int64(500), // less than one block on most filesystems + int64(1 << (10 * 2)), // exactly 1MB + int64(4096), // exactly one block on most filesystems + int64(4097), // plus 1 + } { + s.Run(fmt.Sprintf("File size %d", testFileSize), func() { + fn := s.TempPath() + + origin := io.LimitReader(rand.Reader, testFileSize) + shasum := sha256.New() + reader := io.TeeReader(origin, shasum) + + defer func() { + _ = s.fsys.Remove(fn) + }() + s.Run("Write file", func() { + f, err := s.fsys.OpenFile(fn, os.O_CREATE|os.O_WRONLY, 0644) + s.Require().NoError(err) + n, err := io.Copy(f, reader) + s.Require().NoError(err) + s.Equal(testFileSize, n) + s.Require().NoError(f.Close()) + }) + + s.Run("Verify file size", func() { + stat, err := s.fsys.Stat(fn) + s.Require().NoError(err) + s.Equal(testFileSize, stat.Size()) + }) + + s.Run("Verify file sha256", func() { + sum, err := s.fsys.Sha256(fn) + s.Require().NoError(err) + s.Equal(hex.EncodeToString(shasum.Sum(nil)), sum) + }) + + readSha := sha256.New() + s.Run("Read file", func() { + f, err := s.fsys.Open(fn) + s.Require().NoError(err) + n, err := io.Copy(readSha, f) + s.Require().NoError(err) + s.Equal(testFileSize, n) + s.Require().NoError(f.Close()) + }) + + s.Run("Verify read file sha256", func() { + s.Equal(shasum.Sum(nil), readSha.Sum(nil)) + }) + }) + } +} + +type RepeatReader struct { + data []byte +} + +func (r *RepeatReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = r.data[i%len(r.data)] + } + return len(p), nil +} + +func (s *FsysSuite) TestSeek() { + fn := s.TempPath() + reference := bytes.Repeat([]byte{'a'}, 1024) + defer func() { + _ = s.fsys.Remove(fn) + }() + f, err := s.fsys.OpenFile(fn, os.O_CREATE|os.O_WRONLY, 0644) + s.Require().NoError(err) + n, err := io.Copy(f, bytes.NewReader(bytes.Repeat([]byte{'a'}, 1024))) + s.Require().NoError(err) + s.Equal(int64(1024), n) + s.Require().NoError(f.Close()) + + s.Run("Verify contents", func() { + f, err := s.fsys.Open(fn) + s.Require().NoError(err) + b, err := io.ReadAll(f) + s.Require().NoError(err) + s.Equal(1024, len(b)) + s.Require().NoError(f.Close()) + s.Equal(reference, b) + }) + s.Run("Alter file beginning", func() { + f, err := s.fsys.OpenFile(fn, os.O_WRONLY, 0644) + s.Require().NoError(err) + np, err := f.Seek(0, io.SeekStart) + s.Require().NoError(err) + s.Equal(int64(0), np) + n, err := io.Copy(f, bytes.NewReader(bytes.Repeat([]byte{'b'}, 256))) + s.Require().NoError(err) + s.Equal(int64(256), n) + s.Require().NoError(f.Close()) + }) + copy(reference[0:256], bytes.Repeat([]byte{'b'}, 256)) + s.Run("Verify contents after file beginning altered", func() { + f, err := s.fsys.Open(fn) + s.Require().NoError(err) + b, err := io.ReadAll(f) + s.Require().NoError(err) + s.Equal(1024, len(b)) + s.Require().NoError(f.Close()) + s.Equal(reference, b) + }) + s.Run("Alter file ending", func() { + f, err := s.fsys.OpenFile(fn, os.O_WRONLY, 0644) + s.Require().NoError(err) + np, err := f.Seek(-256, io.SeekEnd) + s.Require().NoError(err) + s.Equal(int64(768), np) + n, err := io.Copy(f, bytes.NewReader(bytes.Repeat([]byte{'c'}, 256))) + s.Require().NoError(err) + s.Equal(int64(256), n) + s.Require().NoError(f.Close()) + }) + copy(reference[768:1024], bytes.Repeat([]byte{'c'}, 256)) + s.Run("Verify contents after file ending altered", func() { + f, err := s.fsys.Open(fn) + s.Require().NoError(err) + b, err := io.ReadAll(f) + s.Require().NoError(err) + s.Equal(1024, len(b)) + s.Require().NoError(f.Close()) + s.Equal(reference, b) + }) + s.Run("Alter file middle", func() { + f, err := s.fsys.OpenFile(fn, os.O_WRONLY, 0644) + s.Require().NoError(err) + np, err := f.Seek(256, io.SeekStart) + s.Require().NoError(err) + s.Equal(int64(256), np) + n, err := io.Copy(f, bytes.NewReader(bytes.Repeat([]byte{'d'}, 512))) + s.Require().NoError(err) + s.Equal(int64(512), n) + s.Require().NoError(f.Close()) + }) + copy(reference[256:768], bytes.Repeat([]byte{'d'}, 512)) + s.Run("Verify contents after file middle altered", func() { + f, err := s.fsys.Open(fn) + s.Require().NoError(err) + b, err := io.ReadAll(f) + s.Require().NoError(err) + s.Equal(1024, len(b)) + s.Require().NoError(f.Close()) + s.Equal(reference, b) + }) +} + +func (s *FsysSuite) TestReadDir() { + defer func() { + _ = s.fsys.RemoveAll(s.TempPath("test")) + }() + s.Run("Create directory", func() { + s.Require().NoError(s.fsys.MkDirAll(s.TempPath("test"), 0755)) + }) + s.Run("Create files", func() { + for _, fn := range []string{s.TempPath("test", "subdir", "nestedfile"), s.TempPath("test", "file")} { + s.Require().NoError(s.fsys.MkDirAll(pathDir(fn), 0755)) + f, err := s.fsys.OpenFile(fn, os.O_CREATE|os.O_WRONLY, 0644) + s.Require().NoError(err) + n, err := f.Write([]byte("test")) + s.Require().NoError(err) + s.Equal(4, n) + s.Require().NoError(f.Close()) + } + }) + + s.Run("Read directory", func() { + dir, err := s.fsys.OpenFile(s.TempPath("test"), os.O_RDONLY, 0644) + s.Require().NoError(err) + s.Require().NotNil(dir) + readDirFile, ok := dir.(fs.ReadDirFile) + s.Require().True(ok) + entries, err := readDirFile.ReadDir(-1) + s.Require().NoError(err) + s.Require().Len(entries, 2) + s.Equal("subdir", entries[0].Name()) + s.True(entries[0].IsDir()) + s.Equal("file", entries[1].Name()) + s.False(entries[1].IsDir()) + s.Require().NoError(dir.Close()) + }) + + s.Run("Walkdir", func() { + var entries []string + s.Require().NoError(fs.WalkDir(s.fsys, s.TempPath("test"), func(path string, d fs.DirEntry, err error) error { + s.Require().NoError(err) + info, err := d.Info() + s.Require().NoError(err) + if strings.HasSuffix(path, "file") { + s.False(info.IsDir()) + s.True(info.Mode().IsRegular()) + } else { + s.True(info.IsDir()) + } + entries = append(entries, path) + return nil + })) + s.Len(entries, 4) + for _, item := range []string{ + s.TempPath("test"), + s.TempPath("test/subdir"), + s.TempPath("test/subdir/nestedfile"), + s.TempPath("test/file"), + } { + s.Contains(entries, item) + } + }) +} diff --git a/test/test.sh b/test/test.sh index 2f99133c..2f0c6499 100755 --- a/test/test.sh +++ b/test/test.sh @@ -4,11 +4,11 @@ RET=0 set -e color_echo() { - echo -e "\033[1;31m$@\033[0m" + echo -e "\033[1;31m$*\033[0m" } ssh_port() { - bootloose show $1 -o json|grep hostPort|grep -oE "[0-9]+" + bootloose show "$1" -o json|grep hostPort|grep -oE "[0-9]+" } sanity_check() { @@ -20,11 +20,11 @@ sanity_check() { docker ps echo "* SSH port: $(ssh_port node0)" echo "* Testing stock ssh" - retry ssh -vvv -o BatchMode=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i .ssh/identity -p $(ssh_port node0) root@127.0.0.1 echo "test-conn" || return $? + retry ssh -vvv -o BatchMode=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i .ssh/identity -p "$(ssh_port node0)" root@127.0.0.1 echo "test-conn" || return $? set +e echo "* Testing bootloose ssh" bootloose ssh root@node0 echo test-conn | grep -q test-conn - local exit_code=$? + exit_code=$? set -e make clean RET=$exit_code @@ -35,8 +35,8 @@ rig_test_key_from_path() { make create-host mv .ssh/identity .ssh/identity2 set +e - ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2 - local exit_code=$? + go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -ssh-keypath .ssh/identity2 + exit_code=$? set -e RET=$exit_code } @@ -44,14 +44,14 @@ rig_test_key_from_path() { rig_test_agent_with_public_key() { color_echo "- Testing connection using agent and providing a path to public key" make create-host - eval $(ssh-agent -s) + eval "$(ssh-agent -s)" ssh-add .ssh/identity rm -f .ssh/identity set +e - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub -connect - local exit_code=$? + HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -ssh-keypath .ssh/identity.pub -connect + exit_code=$? set -e - kill $SSH_AGENT_PID + kill "$SSH_AGENT_PID" export SSH_AGENT_PID= export SSH_AUTH_SOCK= RET=$exit_code @@ -60,7 +60,7 @@ rig_test_agent_with_public_key() { rig_test_agent_with_private_key() { color_echo "- Testing connection using agent and providing a path to protected private key" make create-host KEY_PASSPHRASE=testPhrase - eval $(ssh-agent -s) + eval "$(ssh-agent -s)" expect -c ' spawn ssh-add .ssh/identity expect "?:" @@ -69,8 +69,8 @@ rig_test_agent_with_private_key() { ' set +e # path points to a private key, rig should try to look for the .pub for it - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity -connect - local exit_code=$? + HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -ssh-keypath .ssh/identity -connect + exit_code=$? set -e kill $SSH_AGENT_PID export SSH_AGENT_PID= @@ -81,13 +81,13 @@ rig_test_agent_with_private_key() { rig_test_agent() { color_echo "- Testing connection using any key from agent (empty keypath)" make create-host - eval $(ssh-agent -s) + eval "$(ssh-agent -s)" ssh-add .ssh/identity rm -f .ssh/identity set +e ssh-add -l - HOME=. SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" -connect - local exit_code=$? + HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -ssh-keypath "" -connect + exit_code=$? set -e kill $SSH_AGENT_PID export SSH_AGENT_PID= @@ -99,12 +99,12 @@ rig_test_ssh_config() { color_echo "- Testing getting identity path from ssh config" make create-host mv .ssh/identity .ssh/identity2 - echo "Host 127.0.0.1:$(ssh_port node0)" > .ssh/config - echo " IdentityFile .ssh/identity2" >> .ssh/config + echo "Host 127.0.0.1" > .ssh/config + echo " IdentityFile $(pwd)/.ssh/identity2" >> .ssh/config chmod 0600 .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect - local exit_code=$? + HOME=$(pwd) go test -v ./ -args -ssh-configpath .ssh/config -host 127.0.0.1 -port "$(ssh_port node0)" -user root -connect + exit_code=$? set -e RET=$exit_code } @@ -112,14 +112,18 @@ rig_test_ssh_config() { rig_test_ssh_config_strict() { color_echo "- Testing StrictHostkeyChecking=yes in ssh config" make create-host - local addr="127.0.0.1:$(ssh_port node0)" - echo "Host ${addr}" > .ssh/config - echo " IdentityFile .ssh/identity" >> .ssh/config + port="$(ssh_port node0)" + echo "Host testhost" > .ssh/config + echo " User root" >> .ssh/config + echo " HostName 127.0.0.1" >> .ssh/config + echo " Port ${port}" >> .ssh/config + echo " IdentityFile $(pwd)/.ssh/identity" >> .ssh/config echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config + echo " StrictHostKeyChecking yes" >> .ssh/config cat .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect - local exit_code=$? + HOME=$(pwd) go test -v ./ -args -ssh-configpath .ssh/config -host testhost -connect + exit_code=$? set -e if [ $exit_code -ne 0 ]; then echo " * Failed first checkpoint" @@ -129,10 +133,11 @@ rig_test_ssh_config_strict() { echo " * Passed first checkpoint" cat .ssh/known # modify the known hosts file to make it mismatch - echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known + echo "[127.0.0.1]:$port ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known + echo "[127.0.0.1]:$port ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBGZKwBdFeIPlDWe7otNy4E2Im8+GnQtsukJ5dIuzDGb" >> .ssh/known cat .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + HOME=$(pwd) go test -v ./ -args -ssh-configpath .ssh/config -host testhost -connect exit_code=$? set -e @@ -148,22 +153,27 @@ rig_test_ssh_config_strict() { rig_test_ssh_config_no_strict() { color_echo "- Testing StrictHostkeyChecking=no in ssh config" make create-host - local addr="127.0.0.1:$(ssh_port node0)" - echo "Host ${addr}" > .ssh/config + port="$(ssh_port node0)" + echo "Host testhost" > .ssh/config + echo " User root" >> .ssh/config + echo " HostName 127.0.0.1" >> .ssh/config + echo " Port ${port}" >> .ssh/config echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config echo " StrictHostKeyChecking no" >> .ssh/config + cat .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect - local exit_code=$? + HOME=$(pwd) go test -v ./ -args -ssh-configpath .ssh/config -host testhost -connect + exit_code=$? set -e if [ $? -ne 0 ]; then RET=1 return fi # modify the known hosts file to make it mismatch - echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known + echo "[127.0.0.1]:$port ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known + echo "[127.0.0.1]:$port ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBGZKwBdFeIPlDWe7otNy4E2Im8+GnQtsukJ5dIuzDGb" >> .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + HOME=$(pwd) go test -v ./ -args -ssh-configpath .ssh/config -host testhost -connect exit_code=$? set -e RET=$exit_code @@ -174,8 +184,8 @@ rig_test_key_from_memory() { make create-host mv .ssh/identity .ssh/identity2 set +e - ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -ssh-private-key "$(cat .ssh/identity2)" -connect - local exit_code=$? + go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -ssh-private-key "$(cat .ssh/identity2)" -connect + exit_code=$? set -e RET=$exit_code } @@ -185,36 +195,9 @@ rig_test_key_from_default_location() { make create-host mv .ssh/identity .ssh/id_ecdsa set +e - HOME=$(pwd) ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect - local exit_code=$? - set -e - RET=$exit_code -} - -rig_test_protected_key_from_path() { - color_echo "- Testing regular keypath to encrypted key, two hosts" - make create-host KEY_PASSPHRASE=testPhrase REPLICAS=2 - set +e - ssh_port node0 > .ssh/port_A - ssh_port node1 > .ssh/port_B - expect -c ' - - set fp [open .ssh/port_A r] - set PORTA [read -nonewline $fp] - close $fp - set fp [open .ssh/port_B r] - set PORTB [read -nonewline $fp] - close $fp - - spawn ./rigtest -host 127.0.0.1:$PORTA,127.0.0.1:$PORTB -user root -keypath .ssh/identity -askpass true -connect - expect "Password:" - send "testPhrase\n" - expect eof" - ' $port1 $port2 - local exit_code=$? + HOME=$(pwd) go test -v ./ -args -host 127.0.0.1 -port "$(ssh_port node0)" -user root -connect + exit_code=$? set -e - rm bootloose.yaml - make delete-host REPLICAS=2 RET=$exit_code } @@ -257,7 +240,7 @@ EOF return 0 } - env -i HOME="$(pwd)" ./rigtest -host 127.0.0.1:"$sshPort" -user rigtest-user -keypath .ssh/identity + HOME="$(pwd)" go test -v ./ -args -host 127.0.0.1 -port "$sshPort" -user rigtest-user -ssh-keypath .ssh/identity } rig_test_openssh_client() { @@ -268,9 +251,10 @@ rig_test_openssh_client() { echo " Port $(ssh_port node0)" >> .ssh/config echo " User root" >> .ssh/config echo " IdentityFile $(pwd)/.ssh/identity" >> .ssh/config + cat .ssh/config set +e - SSH_CONFIG=.ssh/config ./rigtest -host testhost -proto openssh -user "" - local exit_code=$? + go test -v ./ -args -ssh-configpath .ssh/config -host testhost -protocol openssh -user "" + exit_code=$? set -e RET=$exit_code } @@ -283,9 +267,10 @@ rig_test_openssh_client_no_multiplex() { echo " Port $(ssh_port node0)" >> .ssh/config echo " User root" >> .ssh/config echo " IdentityFile $(pwd)/.ssh/identity" >> .ssh/config + cat .ssh/config set +e - SSH_CONFIG=.ssh/config ./rigtest -host testhost -proto openssh -user "" -ssh-multiplex=false - local exit_code=$? + go test -v ./ -args -ssh-configpath .ssh/config -host testhost -protocol openssh -user "" -openssh-multiplex=false + exit_code=$? set -e RET=$exit_code } @@ -310,7 +295,6 @@ for test in $(declare -F|grep rig_test_|cut -d" " -f3); do continue fi make clean - make rigtest color_echo "\n###########################################################" RET=0 $test || RET=$? diff --git a/winrm.go b/winrm.go index 34384535..753bb8b1 100644 --- a/winrm.go +++ b/winrm.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "strings" "sync" "time" @@ -192,49 +193,35 @@ type Command struct { stdin io.ReadCloser stdout io.Writer stderr io.Writer + wg sync.WaitGroup } // Wait blocks until the command finishes func (c *Command) Wait() error { - var wg sync.WaitGroup defer c.sh.Close() - if c.stdin == nil { - c.cmd.Stdin.Close() - } else { - wg.Add(1) - go func() { - defer c.cmd.Stdin.Close() - defer wg.Done() - log.Debugf("copying data to stdin") - _, err := io.Copy(c.cmd.Stdin, c.stdin) - if err != nil { - log.Errorf("copying data to command stdin failed: %v", err) - } - }() - } - wg.Add(2) - go func() { - defer wg.Done() - _, _ = io.Copy(c.stdout, c.cmd.Stdout) - }() - go func() { - defer wg.Done() - _, _ = io.Copy(c.stderr, c.cmd.Stderr) - }() + defer c.cmd.Close() + c.wg.Wait() c.cmd.Wait() log.Debugf("command finished") var err error if c.cmd.ExitCode() != 0 { err = fmt.Errorf("%w: exit code %d", ErrCommandFailed, c.cmd.ExitCode()) } - wg.Wait() return err } +// Close terminates the command +func (c *Command) Close() error { + if err := c.cmd.Close(); err != nil { + return fmt.Errorf("close command: %w", err) + } + return nil +} + // ExecStreams executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. -func (c *WinRM) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (waiter, error) { +func (c *WinRM) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error) { if c.client == nil { return nil, ErrNotConnected } @@ -256,7 +243,42 @@ func (c *WinRM) ExecStreams(cmd string, stdin io.ReadCloser, stdout, stderr io.W if err != nil { return nil, fmt.Errorf("%w: execute command: %w", ErrCommandFailed, err) } - return &Command{sh: shell, cmd: proc, stdin: stdin, stdout: stdout, stderr: stderr}, nil + res := &Command{sh: shell, cmd: proc, stdin: stdin, stdout: stdout, stderr: stderr} + if res.stdin == nil { + proc.Stdin.Close() + } else { + res.wg.Add(1) + go func() { + defer res.wg.Done() + log.Debugf("copying data to command stdin") + n, err := io.Copy(res.cmd.Stdin, res.stdin) + if err != nil { + log.Errorf("copying data to command stdin failed: %v", err) + } + log.Debugf("finished copying %d bytes to stdin", n) + }() + } + res.wg.Add(2) + started := time.Now() + go func() { + defer res.wg.Done() + log.Debugf("copying data from command stdout") + n, err := io.Copy(res.stdout, res.cmd.Stdout) + if err != nil { + log.Errorf("copying data from command stdout failed after %s: %v", time.Since(started), err) + } + log.Debugf("finished copying %d bytes from stdout", n) + }() + go func() { + defer res.wg.Done() + log.Debugf("copying data from command stderr") + n, err := io.Copy(res.stderr, res.cmd.Stderr) + if err != nil { + log.Errorf("copying data from command stderr failed after %s: %v", time.Since(started), err) + } + log.Debugf("finished copying %d bytes from stderr", n) + }() + return res, nil } // Exec executes a command on the host @@ -274,6 +296,7 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop if err != nil { return fmt.Errorf("%w: execute command: %w", ErrCommandFailed, err) } + defer command.Close() var wg sync.WaitGroup @@ -308,7 +331,7 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop } }() - gotErrors := false + var errors []string wg.Add(1) go func() { @@ -316,26 +339,27 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop outputScanner := bufio.NewScanner(command.Stderr) for outputScanner.Scan() { - gotErrors = true - execOpts.AddOutput(c.String(), "", outputScanner.Text()+"\n") + msg := outputScanner.Text() + if msg != "" { + errors = append(errors, msg) + execOpts.LogErrorf("%s: %s", c, msg) + } } if err := outputScanner.Err(); err != nil { - gotErrors = true execOpts.LogErrorf("%s: %s", c, err.Error()) } command.Stderr.Close() }() - command.Wait() - wg.Wait() + command.Wait() if ec := command.ExitCode(); ec > 0 { return fmt.Errorf("%w: non-zero exit code: %d", ErrCommandFailed, ec) } - if !execOpts.AllowWinStderr && gotErrors { - return fmt.Errorf("%w: received data in stderr", ErrCommandFailed) + if !execOpts.AllowWinStderr && len(errors) > 0 { + return fmt.Errorf("%w: received data in stderr: %s", ErrCommandFailed, strings.Join(errors, "\n")) } return nil