From ec4bb8f6ab8dd0e397e21a680fbd9b7a058cce92 Mon Sep 17 00:00:00 2001 From: Abdelrahman Abdelhafez Date: Thu, 16 Jun 2022 18:19:57 +0200 Subject: [PATCH] :hammer: Control input source & output/error destination --- cmd/combine.go | 32 +++++++++++++++++++++++++------- cmd/root.go | 25 +++++++++++++++++++------ cmd/split.go | 35 +++++++++++++++++++++++++++-------- main.go | 24 ++++++++++++++++++++++-- utils/bufio.go | 24 ++++++++++++++++++++++++ utils/os.go | 5 +++-- 6 files changed, 120 insertions(+), 25 deletions(-) create mode 100644 utils/bufio.go diff --git a/cmd/combine.go b/cmd/combine.go index feb965e..b64d6fa 100644 --- a/cmd/combine.go +++ b/cmd/combine.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "io" "github.com/manifoldco/promptui" "github.com/spf13/cobra" @@ -11,7 +12,11 @@ import ( ) // Generates the combine command. -func generateCombineCommand() *cobra.Command { +func generateCombineCommand( + inputSource io.Reader, + outputDestination io.Writer, + errorDestination io.Writer, +) *cobra.Command { // Declare command flag values var thresholdCount int @@ -21,7 +26,12 @@ func generateCombineCommand() *cobra.Command { Short: "Reconstruct a secret from shares", Long: "Reconstructs a secret from shares.", Args: cobra.NoArgs, - Run: runCombineCommand(&thresholdCount), + Run: runCombineCommand( + inputSource, + outputDestination, + errorDestination, + &thresholdCount, + ), } // Define command flags @@ -40,12 +50,15 @@ func generateCombineCommand() *cobra.Command { // Runs the combine command. func runCombineCommand( + inputSource io.Reader, + outputDestination io.Writer, + errorDestination io.Writer, thresholdCount *int, ) func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { // Validate flag values if *thresholdCount < 2 || *thresholdCount > 255 { - utils.ExitWithError("threshold must be between 2 and 255") + utils.ExitWithError(errorDestination, fmt.Errorf("threshold must be between 2 and 255")) } // Prompt user for shares @@ -53,7 +66,9 @@ func runCombineCommand( for i := 0; i < *thresholdCount; i++ { prompt := promptui.Prompt{ - Label: fmt.Sprintf("Share #%d", i+1), + Stdin: utils.NopReadCloser(inputSource), + Stdout: utils.NopWriteCloser(outputDestination), + Label: fmt.Sprintf("Share #%d", i+1), Validate: func(input string) error { if len(input) == 0 { return fmt.Errorf("share must not be empty") @@ -73,7 +88,7 @@ func runCombineCommand( share, err := prompt.Run() if err != nil { - utils.ExitWithError(err.Error()) + utils.ExitWithError(errorDestination, err) } shares[i] = share @@ -82,10 +97,13 @@ func runCombineCommand( // Reconstruct secret from shares secret, err := shamir.Combine(shares) if err != nil { - utils.ExitWithError(err.Error()) + utils.ExitWithError(errorDestination, err) } // Print secret - fmt.Println(secret) + _, err = fmt.Fprintln(outputDestination, secret) + if err != nil { + utils.ExitWithError(errorDestination, err) + } } } diff --git a/cmd/root.go b/cmd/root.go index b1f8381..d3c3d55 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,13 +1,18 @@ package cmd import ( + "io" "strings" "github.com/spf13/cobra" ) -// Defines and runs the CLI. -func Execute() { +// Generates the root command. +func GenerateRootCommand( + inputSource io.Reader, + outputDestination io.Writer, + errorDestination io.Writer, +) *cobra.Command { examples := []string{" $ shamir split -n 3 -t 2", " $ shamir combine -t 2"} // Define root command @@ -19,10 +24,18 @@ func Execute() { Example: strings.Join(examples, "\n"), } + // Set inputs & outputs + rootCommand.SetIn(inputSource) + rootCommand.SetOut(outputDestination) + rootCommand.SetErr(errorDestination) + // Define commands - rootCommand.AddCommand(generateSplitCommand()) - rootCommand.AddCommand(generateCombineCommand()) + rootCommand.AddCommand( + generateSplitCommand(inputSource, outputDestination, errorDestination), + ) + rootCommand.AddCommand( + generateCombineCommand(inputSource, outputDestination, errorDestination), + ) - // Run CLI - cobra.CheckErr(rootCommand.Execute()) + return rootCommand } diff --git a/cmd/split.go b/cmd/split.go index 9828631..56a460a 100644 --- a/cmd/split.go +++ b/cmd/split.go @@ -2,6 +2,8 @@ package cmd import ( "fmt" + "io" + "strings" "github.com/manifoldco/promptui" "github.com/spf13/cobra" @@ -11,7 +13,11 @@ import ( ) // Generates the split command. -func generateSplitCommand() *cobra.Command { +func generateSplitCommand( + inputSource io.Reader, + outputDestination io.Writer, + errorDestination io.Writer, +) *cobra.Command { // Declare command flag values var sharesCount int var thresholdCount int @@ -24,7 +30,13 @@ func generateSplitCommand() *cobra.Command { thereof (of length t) is necessary to reconstruct the original secret.`, Args: cobra.NoArgs, - Run: runSplitCommand(&sharesCount, &thresholdCount), + Run: runSplitCommand( + inputSource, + outputDestination, + errorDestination, + &sharesCount, + &thresholdCount, + ), } // Define command flags @@ -52,14 +64,19 @@ original secret.`, // Runs the split command. func runSplitCommand( + inputSource io.Reader, + outputDestination io.Writer, + errorDestination io.Writer, sharesCount *int, thresholdCount *int, ) func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { // Define secret prompt prompt := promptui.Prompt{ - Label: "Secret", - Mask: '*', + Stdin: utils.NopReadCloser(inputSource), + Stdout: utils.NopWriteCloser(outputDestination), + Label: "Secret", + Mask: '*', Validate: func(input string) error { if len(input) == 0 { return fmt.Errorf("secret must not be empty") @@ -72,7 +89,7 @@ func runSplitCommand( // Prompt user for secret secret, err := prompt.Run() if err != nil { - utils.ExitWithError(err.Error()) + utils.ExitWithError(errorDestination, err) } // Split secret into shares @@ -82,12 +99,14 @@ func runSplitCommand( *thresholdCount, ) if err != nil { - utils.ExitWithError(err.Error()) + utils.ExitWithError(errorDestination, err) } // Print shares - for _, share := range shares { - fmt.Println(share) + sharesConcatenated := strings.Join(shares, "\n") + _, err = fmt.Fprintln(outputDestination, sharesConcatenated) + if err != nil { + utils.ExitWithError(errorDestination, err) } } } diff --git a/main.go b/main.go index 1c31830..b44a65e 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,27 @@ package main -import "incipher.io/shamir/cmd" +import ( + "os" + + "incipher.io/shamir/cmd" + "incipher.io/shamir/utils" +) func main() { - cmd.Execute() + inputSource := os.Stdin + outputDestination := os.Stdout + errorDestination := os.Stderr + + // Generate root command + rootCommand := cmd.GenerateRootCommand( + inputSource, + outputDestination, + errorDestination, + ) + + // Run root command + err := rootCommand.Execute() + if err != nil { + utils.ExitWithError(errorDestination, err) + } } diff --git a/utils/bufio.go b/utils/bufio.go new file mode 100644 index 0000000..2175afb --- /dev/null +++ b/utils/bufio.go @@ -0,0 +1,24 @@ +package utils + +import ( + "io" +) + +// Returns an io.ReadCloser with a no-op Close method wrapping the provided reader. +func NopReadCloser(reader io.Reader) io.ReadCloser { + return io.NopCloser(reader) +} + +// Returns an io.WriteCloser with a no-op Close method wrapping the provided writer. +func NopWriteCloser(writer io.Writer) io.WriteCloser { + return &WriteCloser{Writer: writer} +} + +func (writeCloser *WriteCloser) Close() error { + // Noop + return nil +} + +type WriteCloser struct { + io.Writer +} diff --git a/utils/os.go b/utils/os.go index aba3dff..af3f9cf 100644 --- a/utils/os.go +++ b/utils/os.go @@ -2,11 +2,12 @@ package utils import ( "fmt" + "io" "os" ) // Prints to stderr and exits with an error code. -func ExitWithError(err string) { - fmt.Fprintln(os.Stderr, err) +func ExitWithError(errorDestination io.Writer, err error) { + fmt.Fprintln(errorDestination, err.Error()) os.Exit(1) }