diff --git a/client.go b/client.go index ee7d6f8..bde8638 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,6 @@ package scp import ( "bytes" "context" - "errors" "fmt" "io" "io/ioutil" @@ -85,13 +84,24 @@ func (a *Client) SSHClient() *ssh.Client { } // CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. -func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error { +func (a *Client) CopyFromFile( + ctx context.Context, + file os.File, + remotePath string, + permissions string, +) error { return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil) } // CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFromFilePassThru( + ctx context.Context, + file os.File, + remotePath string, + permissions string, + passThru PassThru, +) error { stat, err := file.Stat() if err != nil { return fmt.Errorf("failed to stat file: %w", err) @@ -101,21 +111,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP // CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. -func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error { +func (a *Client) CopyFile( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, +) error { return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil) } // CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFilePassThru( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, + passThru PassThru, +) error { contentsBytes, err := ioutil.ReadAll(fileReader) if err != nil { return fmt.Errorf("failed to read all data from reader: %w", err) } bytesReader := bytes.NewReader(contentsBytes) - return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru) + return a.CopyPassThru( + ctx, + bytesReader, + remotePath, + permissions, + int64(len(contentsBytes)), + passThru, + ) } // wait waits for the waitgroup for the specified max timeout. @@ -139,27 +167,36 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error { // checkResponse checks the response it reads from the remote, and will return a single error in case // of failure. func checkResponse(r io.Reader) error { - response, err := ParseResponse(r) + _, err := ParseResponse(r, nil) if err != nil { return err } - if response.IsFailure() { - return errors.New(response.GetMessage()) - } - return nil } // Copy copies the contents of an io.Reader to a remote location. -func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error { +func (a *Client) Copy( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, +) error { return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil) } // CopyPassThru copies the contents of an io.Reader to a remote location. // Access copied bytes by providing a PassThru reader factory -func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error { +func (a *Client) CopyPassThru( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, + passThru PassThru, +) error { session, err := a.sshClient.NewSession() if err != nil { return fmt.Errorf("Error creating ssh session in copy to remote: %v", err) @@ -264,7 +301,7 @@ func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath strin // CopyFromRemote copies a file from the remote to the local file given by the `file` // parameter. Use `CopyFromRemotePassThru` if a more generic writer -// is desired instead of writing directly to a file on the file system.? +// is desired instead of writing directly to a file on the file system. func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath string) error { return a.CopyFromRemotePassThru(ctx, file, remotePath, nil) } @@ -272,22 +309,51 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s // CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used // to keep track of progress and how many bytes that were download from the remote. // `passThru` can be set to nil to disable this behaviour. -func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error { +func (a *Client) CopyFromRemotePassThru( + ctx context.Context, + w io.Writer, + remotePath string, + passThru PassThru, +) error { + _, err := a.copyFromRemote(ctx, w, remotePath, passThru, false) + + return err +} + +// CopyFroRemoteFileInfos copies a file from the remote to a given writer and return a FileInfos struct +// containing information about the file such as permissions, the file size, modification time and access time +func (a *Client) CopyFromRemoteFileInfos( + ctx context.Context, + w io.Writer, + remotePath string, + passThru PassThru, +) (*FileInfos, error) { + return a.copyFromRemote(ctx, w, remotePath, passThru, true) +} + +func (a *Client) copyFromRemote( + ctx context.Context, + w io.Writer, + remotePath string, + passThru PassThru, + preserveFileTimes bool, +) (*FileInfos, error) { session, err := a.sshClient.NewSession() if err != nil { - return fmt.Errorf("Error creating ssh session in copy from remote: %v", err) + return nil, fmt.Errorf("Error creating ssh session in copy from remote: %v", err) } defer session.Close() wg := sync.WaitGroup{} errCh := make(chan error, 4) + var fileInfos *FileInfos wg.Add(1) go func() { var err error defer func() { - // NOTE: this might send an already sent error another time, but since we only receive opne, this is fine. On the "happy-path" of this function, the error will be `nil` therefore completing the "err<-errCh" at the bottom of the function. + // NOTE: this might send an already sent error another time, but since we only receive one, this is fine. On the "happy-path" of this function, the error will be `nil` therefore completing the "err<-errCh" at the bottom of the function. errCh <- err // We must unblock the go routine first as we block on reading the channel later wg.Done() @@ -307,7 +373,11 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote } defer in.Close() - err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath)) + if preserveFileTimes { + err = session.Start(fmt.Sprintf("%s -pf %q", a.RemoteBinary, remotePath)) + } else { + err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath)) + } if err != nil { errCh <- err return @@ -319,21 +389,13 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote return } - res, err := ParseResponse(r) + fileInfo, err := ParseResponse(r, in) if err != nil { errCh <- err return } - if res.IsFailure() { - errCh <- errors.New(res.GetMessage()) - return - } - infos, err := res.ParseFileInfos() - if err != nil { - errCh <- err - return - } + fileInfos = fileInfo err = Ack(in) if err != nil { @@ -342,10 +404,10 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote } if passThru != nil { - r = passThru(r, infos.Size) + r = passThru(r, fileInfo.Size) } - _, err = CopyN(w, r, infos.Size) + _, err = CopyN(w, r, fileInfo.Size) if err != nil { errCh <- err return @@ -371,11 +433,12 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote } if err := wait(&wg, ctx); err != nil { - return err + return nil, err } + finalErr := <-errCh close(errCh) - return finalErr + return fileInfos, finalErr } func (a *Client) Close() { diff --git a/configurer.go b/configurer.go index 3f1c16e..db1d9f6 100644 --- a/configurer.go +++ b/configurer.go @@ -78,6 +78,6 @@ func (c *ClientConfigurer) Create() Client { Timeout: c.timeout, RemoteBinary: c.remoteBinary, sshClient: c.sshClient, - closeHandler: EmptyHandler{}, + closeHandler: EmptyHandler{}, } } diff --git a/protocol.go b/protocol.go index ca19d71..6871c29 100644 --- a/protocol.go +++ b/protocol.go @@ -9,42 +9,30 @@ package scp import ( "bufio" "errors" + "fmt" "io" "strconv" "strings" ) -type ResponseType = uint8 +type ResponseType = byte const ( Ok ResponseType = 0 Warning ResponseType = 1 Error ResponseType = 2 + Create ResponseType = 'C' + Time ResponseType = 'T' ) -// Response represent a response from the SCP command. -// There are tree types of responses that the remote can send back: -// ok, warning and error -// -// The difference between warning and error is that the connection is not closed by the remote, -// however, a warning can indicate a file transfer failure (such as invalid destination directory) -// and such be handled as such. -// -// All responses except for the `Ok` type always have a message (although these can be empty) -// -// The remote sends a confirmation after every SCP command, because a failure can occur after every -// command, the response should be read and checked after sending them. -type Response struct { - Type ResponseType - Message string -} - // ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure. -func ParseResponse(reader io.Reader) (Response, error) { +func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) { + fileInfos := NewFileInfos() + buffer := make([]uint8, 1) _, err := reader.Read(buffer) if err != nil { - return Response{}, err + return fileInfos, err } responseType := buffer[0] @@ -53,61 +41,154 @@ func ParseResponse(reader io.Reader) (Response, error) { bufferedReader := bufio.NewReader(reader) message, err = bufferedReader.ReadString('\n') if err != nil { - return Response{}, err + return fileInfos, err } - } - return Response{responseType, message}, nil -} + if responseType == Warning || responseType == Error { + return fileInfos, errors.New(message) + } -func (r *Response) IsOk() bool { - return r.Type == Ok -} + // Exit early because we're only interested in the ok response + if responseType == Ok { + return fileInfos, nil + } -func (r *Response) IsWarning() bool { - return r.Type == Warning -} + if !(responseType == Create || responseType == Time) { + return fileInfos, errors.New( + fmt.Sprintf( + "Message does not follow scp protocol: %s\n Cmmmm or T 0 0", + message, + ), + ) + } -// IsError returns true when the remote responded with an error. -func (r *Response) IsError() bool { - return r.Type == Error -} + if responseType == Time { + err = ParseFileTime(message, fileInfos) + if err != nil { + return nil, err + } + + // A custom ssh server can send both time, permissions and size information at once + // without needing an Ack response. Example: wish from charmbracelet while using their default scp implementation + // If the buffer is empty, then it's likely the default implementation for ssh, so send Ack + if bufferedReader.Buffered() == 0 { + err = Ack(writer) + if err != nil { + return fileInfos, err + } + } + + message, err = bufferedReader.ReadString('\n') + + if err != nil { + return fileInfos, err + } + + responseType = message[0] + } -// IsFailure returns true when the remote answered with a warning or an error. -func (r *Response) IsFailure() bool { - return r.IsWarning() || r.IsError() -} + if responseType == Create { + err = ParseFileInfos(message, fileInfos) + if err != nil { + return nil, err + } + } + } -// GetMessage returns the message the remote sent back. -func (r *Response) GetMessage() string { - return r.Message + return fileInfos, nil } type FileInfos struct { Message string Filename string - Permissions string + Permissions uint32 Size int64 + Atime int64 + Mtime int64 +} + +func NewFileInfos() *FileInfos { + return &FileInfos{} +} + +func (fileInfos *FileInfos) Update(new *FileInfos) { + if new == nil { + return + } + if new.Filename != "" { + fileInfos.Filename = new.Filename + } + if new.Permissions != 0 { + fileInfos.Permissions = new.Permissions + } + if new.Size != 0 { + fileInfos.Size = new.Size + } + if new.Atime != 0 { + fileInfos.Atime = new.Atime + } + if new.Mtime != 0 { + fileInfos.Mtime = new.Mtime + } } -func (r *Response) ParseFileInfos() (*FileInfos, error) { - message := strings.ReplaceAll(r.Message, "\n", "") - parts := strings.Split(message, " ") +func ParseFileInfos(message string, fileInfos *FileInfos) error { + processMessage := strings.ReplaceAll(message, "\n", "") + parts := strings.Split(processMessage, " ") if len(parts) < 3 { - return nil, errors.New("unable to parse message as file infos") + return errors.New("unable to parse Chmod protocol") } - size, err := strconv.Atoi(parts[1]) + permissions, err := strconv.ParseUint(parts[0][1:], 0, 32) if err != nil { - return nil, err + return err } - return &FileInfos{ - Message: r.Message, - Permissions: parts[0], - Size: int64(size), + size, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return err + } + + fileInfos.Update(&FileInfos{ Filename: parts[2], - }, nil + Permissions: uint32(permissions), + Size: int64(size), + }) + + return nil +} + +func ParseFileTime( + message string, + fileInfos *FileInfos, +) error { + processMessage := strings.ReplaceAll(message, "\n", "") + parts := strings.Split(processMessage, " ") + if len(parts) < 3 { + return errors.New("unable to parse Time protocol") + } + + if len(parts[0]) != 10 { + return errors.New("length of ATime is not 10") + } + mTime, err := strconv.Atoi(parts[0][0:10]) + if err != nil { + return errors.New("unable to parse ATime component of message") + } + + if len(parts[2]) != 10 { + return errors.New("length of MTime is not 10") + } + aTime, err := strconv.Atoi(parts[2][0:10]) + if err != nil { + return errors.New("unable to parse MTime component of message") + } + + fileInfos.Update(&FileInfos{ + Atime: int64(aTime), + Mtime: int64(mTime), + }) + return nil } // Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is diff --git a/tests/basic_test.go b/tests/basic_test.go index 20f1540..3eff3d8 100644 --- a/tests/basic_test.go +++ b/tests/basic_test.go @@ -3,6 +3,7 @@ package scp import ( "context" "fmt" + "io/fs" "os" "strings" "testing" @@ -207,10 +208,6 @@ func TestDownloadFile(t *testing.T) { client := establishConnection(t) defer client.Close() - // Open a file we can transfer to the remote container. - f, _ := os.Open("./data/input.txt") - defer f.Close() - // Create a local file to write to. f, err := os.OpenFile("./tmp/output.txt", os.O_RDWR|os.O_CREATE, 0777) if err != nil { @@ -237,6 +234,66 @@ func TestDownloadFile(t *testing.T) { } } +func TestDownloadFileInfo(t *testing.T) { + client := establishConnection(t) + defer client.Close() + + // Create a local file to write the remote file to. + f, err := os.OpenFile("./tmp/output.txt", os.O_RDWR|os.O_CREATE, 0777) + if err != nil { + t.Errorf("Couldn't open the output file") + } + defer f.Close() + + // Use a file name with exotic characters and spaces in them. + // If this test works for this, simpler files should not be a problem. + fileInfos, err := client.CopyFromRemoteFileInfos( + context.Background(), + f, + "/input/Exöt1ç download file.txt.txt", + nil, + ) + if err != nil { + t.Errorf("Copy failed from remote: %s", err.Error()) + } + + content, err := os.ReadFile("./tmp/output.txt") + if err != nil { + t.Errorf("Result file could not be read: %s", err) + } + + text := string(content) + expected := "It works for download!\n" + if strings.Compare(text, expected) != 0 { + t.Errorf("Got different text than expected, expected %q got, %q", expected, text) + } + + fileStat, err := os.Stat("./data/Exöt1ç download file.txt.txt") + if err != nil { + t.Errorf("Result file could not be read: %s", err) + } + + if fileInfos.Size != fileStat.Size() { + t.Errorf("File size does not match") + } + + if fs.FileMode(fileInfos.Permissions) != fileStat.Mode() { + t.Errorf( + "File permissions don't match %s vs %s", + fs.FileMode(fileInfos.Permissions), + fileStat.Mode().Perm(), + ) + } + + if fileInfos.Mtime != fileStat.ModTime().Unix() { + t.Errorf( + "File modification time does not match %d vs %d", + fileInfos.Mtime, + fileStat.ModTime().Unix(), + ) + } +} + // TestTimeoutDownload tests that a timeout error is produced if the file is not copied in the given // amount of time. func TestTimeoutDownload(t *testing.T) {