Skip to content

Commit

Permalink
contexts (gliderlabs#29)
Browse files Browse the repository at this point in the history
* context: working mostly tested context implementation and refactoring to go with it
* _example/ssh-publickey: updating new context based callbacks
* godocs related to public api changes for contexts
* context: converting []bytes to strings before putting into context

Signed-off-by: Jeff Lindsay <progrium@gmail.com>
  • Loading branch information
progrium authored Mar 14, 2017
1 parent 791cd4b commit 9b56478
Show file tree
Hide file tree
Showing 11 changed files with 343 additions and 73 deletions.
2 changes: 1 addition & 1 deletion _example/ssh-publickey/public_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func main() {
s.Write(authorizedKey)
})

publicKeyOption := ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
return true // allow all keys, or use ssh.KeysEqual() to compare against known keys
})

Expand Down
142 changes: 142 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package ssh

import (
"context"
"net"

gossh "golang.org/x/crypto/ssh"
)

// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}

var (
// ContextKeyUser is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyUser = &contextKey{"user"}

// ContextKeySessionID is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeySessionID = &contextKey{"session-id"}

// ContextKeyPermissions is a context key for use with Contexts in this package.
// The associated value will be of type *Permissions.
ContextKeyPermissions = &contextKey{"permissions"}

// ContextKeyClientVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyClientVersion = &contextKey{"client-version"}

// ContextKeyServerVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyServerVersion = &contextKey{"server-version"}

// ContextKeyLocalAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyLocalAddr = &contextKey{"local-addr"}

// ContextKeyRemoteAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyRemoteAddr = &contextKey{"remote-addr"}

// ContextKeyServer is a context key for use with Contexts in this package.
// The associated value will be of type *Server.
ContextKeyServer = &contextKey{"ssh-server"}

// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}
)

// Context is a package specific context interface. It exposes connection
// metadata and allows new values to be easily written to it. It's used in
// authentication handlers and callbacks, and its underlying context.Context is
// exposed on Session in the session Handler.
type Context interface {
context.Context

// User returns the username used when establishing the SSH connection.
User() string

// SessionID returns the session hash.
SessionID() string

// ClientVersion returns the version reported by the client.
ClientVersion() string

// ServerVersion returns the version reported by the server.
ServerVersion() string

// RemoteAddr returns the remote address for this connection.
RemoteAddr() net.Addr

// LocalAddr returns the local address for this connection.
LocalAddr() net.Addr

// Permissions returns the Permissions object used for this connection.
Permissions() *Permissions

// SetValue allows you to easily write new values into the underlying context.
SetValue(key, value interface{})
}

type sshContext struct {
context.Context
}

func newContext(srv *Server) *sshContext {
ctx := &sshContext{context.Background()}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
return ctx
}

// this is separate from newContext because we will get ConnMetadata
// at different points so it needs to be applied separately
func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
ctx.SetValue(ContextKeyUser, conn.User())
ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}

func (ctx *sshContext) User() string {
return ctx.Value(ContextKeyUser).(string)
}

func (ctx *sshContext) SessionID() string {
return ctx.Value(ContextKeySessionID).(string)
}

func (ctx *sshContext) ClientVersion() string {
return ctx.Value(ContextKeyClientVersion).(string)
}

func (ctx *sshContext) ServerVersion() string {
return ctx.Value(ContextKeyServerVersion).(string)
}

func (ctx *sshContext) RemoteAddr() net.Addr {
return ctx.Value(ContextKeyRemoteAddr).(net.Addr)
}

func (ctx *sshContext) LocalAddr() net.Addr {
return ctx.Value(ContextKeyLocalAddr).(net.Addr)
}

func (ctx *sshContext) Permissions() *Permissions {
return ctx.Value(ContextKeyPermissions).(*Permissions)
}
47 changes: 47 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package ssh

import "testing"

func TestSetPermissions(t *testing.T) {
t.Parallel()
permsExt := map[string]string{
"foo": "bar",
}
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
if _, ok := s.Permissions().Extensions["foo"]; !ok {
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.Permissions().Extensions = permsExt
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestSetValue(t *testing.T) {
t.Parallel()
value := map[string]string{
"foo": "bar",
}
key := "testValue"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
v := s.Context().Value(key).(map[string]string)
if v["foo"] != value["foo"] {
t.Fatalf("got %#v; want %#v", v, value)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.SetValue(key, value)
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func ExampleListenAndServe() {

func ExamplePasswordAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PasswordAuth(func(user, pass string) bool {
ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
return pass == "secret"
}),
)
Expand All @@ -27,7 +27,7 @@ func ExampleNoPty() {

func ExamplePublicKeyAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
return ssh.KeysEqual(key, allowed)
Expand Down
2 changes: 1 addition & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(user string, permissions *Permissions) bool {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
return false
}
return nil
Expand Down
66 changes: 66 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package ssh

import (
"strings"
"testing"

gossh "golang.org/x/crypto/ssh"
)

func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
for _, option := range options {
if err := srv.SetOption(option); err != nil {
t.Fatal(err)
}
}
return newTestSession(t, srv, cfg)
}

func TestPasswordAuth(t *testing.T) {
t.Parallel()
testUser := "testuser"
testPass := "testpass"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// noop
},
}, &gossh.ClientConfig{
User: testUser,
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
}
if password != testPass {
t.Fatalf("user = %#v; want %#v", password, testPass)
}
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
}))
go srv.serveOnce(l)
_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {
t.Fatal(err)
}
}
}
Loading

0 comments on commit 9b56478

Please sign in to comment.