Skip to content

Commit

Permalink
🔨 Control input source & output/error destination
Browse files Browse the repository at this point in the history
  • Loading branch information
itsabdelrahman committed Jun 16, 2022
1 parent 454504e commit ec4bb8f
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 25 deletions.
32 changes: 25 additions & 7 deletions cmd/combine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"io"

"github.com/manifoldco/promptui"
"github.com/spf13/cobra"
Expand All @@ -11,7 +12,11 @@ import (
)

// Generates the combine command.
func generateCombineCommand() *cobra.Command {
func generateCombineCommand(
inputSource io.Reader,
outputDestination io.Writer,
errorDestination io.Writer,
) *cobra.Command {
// Declare command flag values
var thresholdCount int

Expand All @@ -21,7 +26,12 @@ func generateCombineCommand() *cobra.Command {
Short: "Reconstruct a secret from shares",
Long: "Reconstructs a secret from shares.",
Args: cobra.NoArgs,
Run: runCombineCommand(&thresholdCount),
Run: runCombineCommand(
inputSource,
outputDestination,
errorDestination,
&thresholdCount,
),
}

// Define command flags
Expand All @@ -40,20 +50,25 @@ func generateCombineCommand() *cobra.Command {

// Runs the combine command.
func runCombineCommand(
inputSource io.Reader,
outputDestination io.Writer,
errorDestination io.Writer,
thresholdCount *int,
) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
// Validate flag values
if *thresholdCount < 2 || *thresholdCount > 255 {
utils.ExitWithError("threshold must be between 2 and 255")
utils.ExitWithError(errorDestination, fmt.Errorf("threshold must be between 2 and 255"))
}

// Prompt user for shares
shares := make([]string, *thresholdCount)

for i := 0; i < *thresholdCount; i++ {
prompt := promptui.Prompt{
Label: fmt.Sprintf("Share #%d", i+1),
Stdin: utils.NopReadCloser(inputSource),
Stdout: utils.NopWriteCloser(outputDestination),
Label: fmt.Sprintf("Share #%d", i+1),
Validate: func(input string) error {
if len(input) == 0 {
return fmt.Errorf("share must not be empty")
Expand All @@ -73,7 +88,7 @@ func runCombineCommand(

share, err := prompt.Run()
if err != nil {
utils.ExitWithError(err.Error())
utils.ExitWithError(errorDestination, err)
}

shares[i] = share
Expand All @@ -82,10 +97,13 @@ func runCombineCommand(
// Reconstruct secret from shares
secret, err := shamir.Combine(shares)
if err != nil {
utils.ExitWithError(err.Error())
utils.ExitWithError(errorDestination, err)
}

// Print secret
fmt.Println(secret)
_, err = fmt.Fprintln(outputDestination, secret)
if err != nil {
utils.ExitWithError(errorDestination, err)
}
}
}
25 changes: 19 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package cmd

import (
"io"
"strings"

"github.com/spf13/cobra"
)

// Defines and runs the CLI.
func Execute() {
// Generates the root command.
func GenerateRootCommand(
inputSource io.Reader,
outputDestination io.Writer,
errorDestination io.Writer,
) *cobra.Command {
examples := []string{" $ shamir split -n 3 -t 2", " $ shamir combine -t 2"}

// Define root command
Expand All @@ -19,10 +24,18 @@ func Execute() {
Example: strings.Join(examples, "\n"),
}

// Set inputs & outputs
rootCommand.SetIn(inputSource)
rootCommand.SetOut(outputDestination)
rootCommand.SetErr(errorDestination)

// Define commands
rootCommand.AddCommand(generateSplitCommand())
rootCommand.AddCommand(generateCombineCommand())
rootCommand.AddCommand(
generateSplitCommand(inputSource, outputDestination, errorDestination),
)
rootCommand.AddCommand(
generateCombineCommand(inputSource, outputDestination, errorDestination),
)

// Run CLI
cobra.CheckErr(rootCommand.Execute())
return rootCommand
}
35 changes: 27 additions & 8 deletions cmd/split.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package cmd

import (
"fmt"
"io"
"strings"

"github.com/manifoldco/promptui"
"github.com/spf13/cobra"
Expand All @@ -11,7 +13,11 @@ import (
)

// Generates the split command.
func generateSplitCommand() *cobra.Command {
func generateSplitCommand(
inputSource io.Reader,
outputDestination io.Writer,
errorDestination io.Writer,
) *cobra.Command {
// Declare command flag values
var sharesCount int
var thresholdCount int
Expand All @@ -24,7 +30,13 @@ func generateSplitCommand() *cobra.Command {
thereof (of length t) is necessary to reconstruct the
original secret.`,
Args: cobra.NoArgs,
Run: runSplitCommand(&sharesCount, &thresholdCount),
Run: runSplitCommand(
inputSource,
outputDestination,
errorDestination,
&sharesCount,
&thresholdCount,
),
}

// Define command flags
Expand Down Expand Up @@ -52,14 +64,19 @@ original secret.`,

// Runs the split command.
func runSplitCommand(
inputSource io.Reader,
outputDestination io.Writer,
errorDestination io.Writer,
sharesCount *int,
thresholdCount *int,
) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
// Define secret prompt
prompt := promptui.Prompt{
Label: "Secret",
Mask: '*',
Stdin: utils.NopReadCloser(inputSource),
Stdout: utils.NopWriteCloser(outputDestination),
Label: "Secret",
Mask: '*',
Validate: func(input string) error {
if len(input) == 0 {
return fmt.Errorf("secret must not be empty")
Expand All @@ -72,7 +89,7 @@ func runSplitCommand(
// Prompt user for secret
secret, err := prompt.Run()
if err != nil {
utils.ExitWithError(err.Error())
utils.ExitWithError(errorDestination, err)
}

// Split secret into shares
Expand All @@ -82,12 +99,14 @@ func runSplitCommand(
*thresholdCount,
)
if err != nil {
utils.ExitWithError(err.Error())
utils.ExitWithError(errorDestination, err)
}

// Print shares
for _, share := range shares {
fmt.Println(share)
sharesConcatenated := strings.Join(shares, "\n")
_, err = fmt.Fprintln(outputDestination, sharesConcatenated)
if err != nil {
utils.ExitWithError(errorDestination, err)
}
}
}
24 changes: 22 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
package main

import "incipher.io/shamir/cmd"
import (
"os"

"incipher.io/shamir/cmd"
"incipher.io/shamir/utils"
)

func main() {
cmd.Execute()
inputSource := os.Stdin
outputDestination := os.Stdout
errorDestination := os.Stderr

// Generate root command
rootCommand := cmd.GenerateRootCommand(
inputSource,
outputDestination,
errorDestination,
)

// Run root command
err := rootCommand.Execute()
if err != nil {
utils.ExitWithError(errorDestination, err)
}
}
24 changes: 24 additions & 0 deletions utils/bufio.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package utils

import (
"io"
)

// Returns an io.ReadCloser with a no-op Close method wrapping the provided reader.
func NopReadCloser(reader io.Reader) io.ReadCloser {
return io.NopCloser(reader)
}

// Returns an io.WriteCloser with a no-op Close method wrapping the provided writer.
func NopWriteCloser(writer io.Writer) io.WriteCloser {
return &WriteCloser{Writer: writer}
}

func (writeCloser *WriteCloser) Close() error {
// Noop
return nil
}

type WriteCloser struct {
io.Writer
}
5 changes: 3 additions & 2 deletions utils/os.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package utils

import (
"fmt"
"io"
"os"
)

// Prints to stderr and exits with an error code.
func ExitWithError(err string) {
fmt.Fprintln(os.Stderr, err)
func ExitWithError(errorDestination io.Writer, err error) {
fmt.Fprintln(errorDestination, err.Error())
os.Exit(1)
}

0 comments on commit ec4bb8f

Please sign in to comment.