Skip to content

Commit

Permalink
age: remove Type method from Recipient and Identity interfaces
Browse files Browse the repository at this point in the history
The Type() method was a mistake, as proven by the fact that I can remove
it without losing any functionality. It gives special meaning to the
"0th argument" of recipient stanzas, when actually it should be left up
to Recipient implementations to make their own stanzas recognizable to
their Identity counterparts.

More importantly, there are totally reasonable Identity (and probably
Recipient) implementations that don't know their own stanza type in
advance. For example, a proxy plugin.

Concretely, it was only used to special-case "scrypt" recipients, and to
skip invoking Unwrap. The former can be done based on the returned
recipient stanza, and the latter is best avoided entirely: the Identity
should start by looking at the stanza and returning ErrIncorrectIdentity
if it's of the wrong type.

This is a breaking API change. However, we are still in beta, and none
of the public downstreams look like they would be affected, as they only
use Recipient and Identity implementations from this package, they only
use them with the interfaces defined in this package, and they don't
directly use the Type() method.
  • Loading branch information
FiloSottile committed Feb 8, 2021
1 parent 15df6e2 commit 6546df3
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 58 deletions.
18 changes: 6 additions & 12 deletions age.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ import (
// An Identity is a private key or other value that can decrypt an opaque file
// key from a recipient stanza.
//
// Unwrap must return ErrIncorrectIdentity for recipient blocks that don't match
// the identity, any other error might be considered fatal.
// Unwrap must return ErrIncorrectIdentity for recipient stanzas that don't
// match the identity, any other error might be considered fatal.
type Identity interface {
Type() string
Unwrap(block *Stanza) (fileKey []byte, err error)
}

Expand All @@ -75,7 +74,6 @@ var ErrIncorrectIdentity = errors.New("incorrect identity for recipient block")
// A Recipient is a public key or other value that can encrypt an opaque file
// key to a recipient stanza.
type Recipient interface {
Type() string
Wrap(fileKey []byte) (*Stanza, error)
}

Expand Down Expand Up @@ -109,15 +107,15 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {

hdr := &format.Header{}
for i, r := range recipients {
if r.Type() == "scrypt" && len(recipients) != 1 {
return nil, errors.New("an scrypt recipient must be the only one")
}

block, err := r.Wrap(fileKey)
if err != nil {
return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
}
hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(block))

if block.Type == "scrypt" && len(recipients) != 1 {
return nil, errors.New("an scrypt recipient must be the only one")
}
}
if mac, err := headerMAC(fileKey, hdr); err != nil {
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
Expand Down Expand Up @@ -163,10 +161,6 @@ RecipientsLoop:
return nil, errors.New("an scrypt recipient must be the only one")
}
for _, i := range identities {
if i.Type() != r.Type {
continue
}

if i, ok := i.(IdentityMatcher); ok {
err := i.Match((*Stanza)(r))
if err != nil {
Expand Down
8 changes: 0 additions & 8 deletions agessh/agessh.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ type RSARecipient struct {

var _ age.Recipient = &RSARecipient{}

func (*RSARecipient) Type() string { return "ssh-rsa" }

func NewRSARecipient(pk ssh.PublicKey) (*RSARecipient, error) {
if pk.Type() != "ssh-rsa" {
return nil, errors.New("SSH public key is not an RSA key")
Expand Down Expand Up @@ -93,8 +91,6 @@ type RSAIdentity struct {

var _ age.Identity = &RSAIdentity{}

func (*RSAIdentity) Type() string { return "ssh-rsa" }

func NewRSAIdentity(key *rsa.PrivateKey) (*RSAIdentity, error) {
s, err := ssh.NewSignerFromKey(key)
if err != nil {
Expand Down Expand Up @@ -133,8 +129,6 @@ type Ed25519Recipient struct {

var _ age.Recipient = &Ed25519Recipient{}

func (*Ed25519Recipient) Type() string { return "ssh-ed25519" }

func NewEd25519Recipient(pk ssh.PublicKey) (*Ed25519Recipient, error) {
if pk.Type() != "ssh-ed25519" {
return nil, errors.New("SSH public key is not an Ed25519 key")
Expand Down Expand Up @@ -246,8 +240,6 @@ type Ed25519Identity struct {

var _ age.Identity = &Ed25519Identity{}

func (*Ed25519Identity) Type() string { return "ssh-ed25519" }

func NewEd25519Identity(key ed25519.PrivateKey) (*Ed25519Identity, error) {
s, err := ssh.NewSignerFromKey(key)
if err != nil {
Expand Down
8 changes: 0 additions & 8 deletions agessh/agessh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ func TestSSHRSARoundTrip(t *testing.T) {
t.Fatal(err)
}

if r.Type() != i.Type() || r.Type() != "ssh-rsa" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}

fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -82,10 +78,6 @@ func TestSSHEd25519RoundTrip(t *testing.T) {
t.Fatal(err)
}

if r.Type() != i.Type() || r.Type() != "ssh-ed25519" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}

fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey); err != nil {
t.Fatal(err)
Expand Down
18 changes: 8 additions & 10 deletions agessh/encrypted_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ func NewEncryptedSSHIdentity(pubKey ssh.PublicKey, pemBytes []byte, passphrase f

var _ age.IdentityMatcher = &EncryptedSSHIdentity{}

// Type returns the type of the underlying private key, "ssh-ed25519" or "ssh-rsa".
func (i *EncryptedSSHIdentity) Type() string {
return i.pubKey.Type()
}

// Unwrap implements age.Identity. If the private key is still encrypted, it
// will request the passphrase. The decrypted private key will be cached after
// the first successful invocation.
Expand All @@ -81,29 +76,32 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er
switch k := k.(type) {
case *ed25519.PrivateKey:
i.decrypted, err = NewEd25519Identity(*k)
if i.pubKey.Type() != ssh.KeyAlgoED25519 {
return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", ssh.KeyAlgoED25519, i.pubKey.Type())
}
case *rsa.PrivateKey:
i.decrypted, err = NewRSAIdentity(k)
if i.pubKey.Type() != ssh.KeyAlgoRSA {
return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", ssh.KeyAlgoRSA, i.pubKey.Type())
}
default:
return nil, fmt.Errorf("unexpected SSH key type: %T", k)
}
if err != nil {
return nil, fmt.Errorf("invalid SSH key: %v", err)
}
if i.decrypted.Type() != i.pubKey.Type() {
return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", i.decrypted.Type(), i.pubKey.Type())
}

return i.decrypted.Unwrap(block)
}

// Match implements age.IdentityMatcher without decrypting the private key, to
// ensure the passphrase is only obtained if necessary.
func (i *EncryptedSSHIdentity) Match(block *age.Stanza) error {
if block.Type != i.Type() {
if block.Type != i.pubKey.Type() {
return age.ErrIncorrectIdentity
}
if len(block.Args) < 1 {
return fmt.Errorf("invalid %v recipient block", i.Type())
return fmt.Errorf("invalid %v recipient block", i.pubKey.Type())
}

if block.Args[0] != sshFingerprint(i.pubKey) {
Expand Down
7 changes: 3 additions & 4 deletions cmd/age/encrypted_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ type LazyScryptIdentity struct {

var _ age.Identity = &LazyScryptIdentity{}

func (i *LazyScryptIdentity) Type() string {
return "scrypt"
}

func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) {
if block.Type != "scrypt" {
return nil, age.ErrIncorrectIdentity
}
pass, err := i.Passphrase()
if err != nil {
return nil, fmt.Errorf("could not read passphrase: %v", err)
Expand Down
8 changes: 0 additions & 8 deletions recipients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ func TestX25519RoundTrip(t *testing.T) {
}
r := i.Recipient()

if r.Type() != i.Type() || r.Type() != "X25519" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}

if r1, err := age.ParseX25519Recipient(r.String()); err != nil {
t.Fatal(err)
} else if r1.String() != r.String() {
Expand Down Expand Up @@ -72,10 +68,6 @@ func TestScryptRoundTrip(t *testing.T) {
t.Fatal(err)
}

if r.Type() != i.Type() || r.Type() != "scrypt" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}

fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey); err != nil {
t.Fatal(err)
Expand Down
4 changes: 0 additions & 4 deletions scrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ type ScryptRecipient struct {

var _ Recipient = &ScryptRecipient{}

func (*ScryptRecipient) Type() string { return "scrypt" }

// NewScryptRecipient returns a new ScryptRecipient with the provided password.
func NewScryptRecipient(password string) (*ScryptRecipient, error) {
if len(password) == 0 {
Expand Down Expand Up @@ -98,8 +96,6 @@ type ScryptIdentity struct {

var _ Identity = &ScryptIdentity{}

func (*ScryptIdentity) Type() string { return "scrypt" }

// NewScryptIdentity returns a new ScryptIdentity with the provided password.
func NewScryptIdentity(password string) (*ScryptIdentity, error) {
if len(password) == 0 {
Expand Down
4 changes: 0 additions & 4 deletions x25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ type X25519Recipient struct {

var _ Recipient = &X25519Recipient{}

func (*X25519Recipient) Type() string { return "X25519" }

// newX25519RecipientFromPoint returns a new X25519Recipient from a raw Curve25519 point.
func newX25519RecipientFromPoint(publicKey []byte) (*X25519Recipient, error) {
if len(publicKey) != curve25519.PointSize {
Expand Down Expand Up @@ -117,8 +115,6 @@ type X25519Identity struct {

var _ Identity = &X25519Identity{}

func (*X25519Identity) Type() string { return "X25519" }

// newX25519IdentityFromScalar returns a new X25519Identity from a raw Curve25519 scalar.
func newX25519IdentityFromScalar(secretKey []byte) (*X25519Identity, error) {
if len(secretKey) != curve25519.ScalarSize {
Expand Down

0 comments on commit 6546df3

Please sign in to comment.