diff --git a/_example/ssh-publickey/public_key.go b/_example/ssh-publickey/public_key.go index 215292e..453ce04 100644 --- a/_example/ssh-publickey/public_key.go +++ b/_example/ssh-publickey/public_key.go @@ -16,7 +16,7 @@ func main() { s.Write(authorizedKey) }) - publicKeyOption := ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool { + publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { return true // allow all keys, or use ssh.KeysEqual() to compare against known keys }) diff --git a/context.go b/context.go new file mode 100644 index 0000000..f666d3a --- /dev/null +++ b/context.go @@ -0,0 +1,142 @@ +package ssh + +import ( + "context" + "net" + + gossh "golang.org/x/crypto/ssh" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +var ( + // ContextKeyUser is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyUser = &contextKey{"user"} + + // ContextKeySessionID is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeySessionID = &contextKey{"session-id"} + + // ContextKeyPermissions is a context key for use with Contexts in this package. + // The associated value will be of type *Permissions. + ContextKeyPermissions = &contextKey{"permissions"} + + // ContextKeyClientVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyClientVersion = &contextKey{"client-version"} + + // ContextKeyServerVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyServerVersion = &contextKey{"server-version"} + + // ContextKeyLocalAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyLocalAddr = &contextKey{"local-addr"} + + // ContextKeyRemoteAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyRemoteAddr = &contextKey{"remote-addr"} + + // ContextKeyServer is a context key for use with Contexts in this package. + // The associated value will be of type *Server. + ContextKeyServer = &contextKey{"ssh-server"} + + // ContextKeyPublicKey is a context key for use with Contexts in this package. + // The associated value will be of type PublicKey. + ContextKeyPublicKey = &contextKey{"public-key"} +) + +// Context is a package specific context interface. It exposes connection +// metadata and allows new values to be easily written to it. It's used in +// authentication handlers and callbacks, and its underlying context.Context is +// exposed on Session in the session Handler. +type Context interface { + context.Context + + // User returns the username used when establishing the SSH connection. + User() string + + // SessionID returns the session hash. + SessionID() string + + // ClientVersion returns the version reported by the client. + ClientVersion() string + + // ServerVersion returns the version reported by the server. + ServerVersion() string + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr + + // Permissions returns the Permissions object used for this connection. + Permissions() *Permissions + + // SetValue allows you to easily write new values into the underlying context. + SetValue(key, value interface{}) +} + +type sshContext struct { + context.Context +} + +func newContext(srv *Server) *sshContext { + ctx := &sshContext{context.Background()} + ctx.SetValue(ContextKeyServer, srv) + perms := &Permissions{&gossh.Permissions{}} + ctx.SetValue(ContextKeyPermissions, perms) + return ctx +} + +// this is separate from newContext because we will get ConnMetadata +// at different points so it needs to be applied separately +func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) { + if ctx.Value(ContextKeySessionID) != nil { + return + } + ctx.SetValue(ContextKeySessionID, string(conn.SessionID())) + ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) + ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) + ctx.SetValue(ContextKeyUser, conn.User()) + ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) + ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) +} + +func (ctx *sshContext) SetValue(key, value interface{}) { + ctx.Context = context.WithValue(ctx.Context, key, value) +} + +func (ctx *sshContext) User() string { + return ctx.Value(ContextKeyUser).(string) +} + +func (ctx *sshContext) SessionID() string { + return ctx.Value(ContextKeySessionID).(string) +} + +func (ctx *sshContext) ClientVersion() string { + return ctx.Value(ContextKeyClientVersion).(string) +} + +func (ctx *sshContext) ServerVersion() string { + return ctx.Value(ContextKeyServerVersion).(string) +} + +func (ctx *sshContext) RemoteAddr() net.Addr { + return ctx.Value(ContextKeyRemoteAddr).(net.Addr) +} + +func (ctx *sshContext) LocalAddr() net.Addr { + return ctx.Value(ContextKeyLocalAddr).(net.Addr) +} + +func (ctx *sshContext) Permissions() *Permissions { + return ctx.Value(ContextKeyPermissions).(*Permissions) +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..f7be4d9 --- /dev/null +++ b/context_test.go @@ -0,0 +1,47 @@ +package ssh + +import "testing" + +func TestSetPermissions(t *testing.T) { + t.Parallel() + permsExt := map[string]string{ + "foo": "bar", + } + session, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + if _, ok := s.Permissions().Extensions["foo"]; !ok { + t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.Permissions().Extensions = permsExt + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestSetValue(t *testing.T) { + t.Parallel() + value := map[string]string{ + "foo": "bar", + } + key := "testValue" + session, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + v := s.Context().Value(key).(map[string]string) + if v["foo"] != value["foo"] { + t.Fatalf("got %#v; want %#v", v, value) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.SetValue(key, value) + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} diff --git a/example_test.go b/example_test.go index 38707d4..972d3ef 100644 --- a/example_test.go +++ b/example_test.go @@ -15,7 +15,7 @@ func ExampleListenAndServe() { func ExamplePasswordAuth() { ssh.ListenAndServe(":2222", nil, - ssh.PasswordAuth(func(user, pass string) bool { + ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { return pass == "secret" }), ) @@ -27,7 +27,7 @@ func ExampleNoPty() { func ExamplePublicKeyAuth() { ssh.ListenAndServe(":2222", nil, - ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool { + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { data, _ := ioutil.ReadFile("/path/to/allowed/key.pub") allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data) return ssh.KeysEqual(key, allowed) diff --git a/options.go b/options.go index 3174972..cd05823 100644 --- a/options.go +++ b/options.go @@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option { // denying PTY requests. func NoPty() Option { return func(srv *Server) error { - srv.PtyCallback = func(user string, permissions *Permissions) bool { + srv.PtyCallback = func(ctx Context, pty Pty) bool { return false } return nil diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..d6ea4b5 --- /dev/null +++ b/options_test.go @@ -0,0 +1,66 @@ +package ssh + +import ( + "strings" + "testing" + + gossh "golang.org/x/crypto/ssh" +) + +func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) { + for _, option := range options { + if err := srv.SetOption(option); err != nil { + t.Fatal(err) + } + } + return newTestSession(t, srv, cfg) +} + +func TestPasswordAuth(t *testing.T) { + t.Parallel() + testUser := "testuser" + testPass := "testpass" + session, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, &gossh.ClientConfig{ + User: testUser, + Auth: []gossh.AuthMethod{ + gossh.Password(testPass), + }, + }, PasswordAuth(func(ctx Context, password string) bool { + if ctx.User() != testUser { + t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) + } + if password != testPass { + t.Fatalf("user = %#v; want %#v", password, testPass) + } + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestPasswordAuthBadPass(t *testing.T) { + t.Parallel() + l := newLocalListener() + srv := &Server{Handler: func(s Session) {}} + srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { + return false + })) + go srv.serveOnce(l) + _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + }) + if err != nil { + if !strings.Contains(err.Error(), "unable to authenticate") { + t.Fatal(err) + } + } +} diff --git a/server.go b/server.go index 9ea48ec..27ea7cb 100644 --- a/server.go +++ b/server.go @@ -17,21 +17,24 @@ type Server struct { HostSigners []Signer // private keys for the host key, must have at least one Version string // server version to be sent before the initial handshake - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - PermissionsCallback PermissionsCallback // optional callback for setting up permissions + PasswordHandler PasswordHandler // password authentication handler + PublicKeyHandler PublicKeyHandler // public key authentication handler + PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil } -func (srv *Server) makeConfig() (*gossh.ServerConfig, error) { - config := &gossh.ServerConfig{} +func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { signer, err := generateSigner() if err != nil { - return nil, err + return err } srv.HostSigners = append(srv.HostSigners, signer) } + return nil +} + +func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig { + config := &gossh.ServerConfig{} for _, signer := range srv.HostSigners { config.AddHostKey(signer) } @@ -43,34 +46,24 @@ func (srv *Server) makeConfig() (*gossh.ServerConfig, error) { } if srv.PasswordHandler != nil { config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - perms := &gossh.Permissions{} - if ok := srv.PasswordHandler(conn.User(), string(password)); !ok { - return perms, fmt.Errorf("permission denied") + ctx.applyConnMetadata(conn) + if ok := srv.PasswordHandler(ctx, string(password)); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } - if srv.PermissionsCallback != nil { - srv.PermissionsCallback(conn.User(), &Permissions{perms}) - } - return perms, nil + return ctx.Permissions().Permissions, nil } } if srv.PublicKeyHandler != nil { config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - perms := &gossh.Permissions{} - if ok := srv.PublicKeyHandler(conn.User(), key); !ok { - return perms, fmt.Errorf("permission denied") - } - // no other way to pass the key from - // auth handler to session handler - perms.Extensions = map[string]string{ - "_publickey": string(key.Marshal()), + ctx.applyConnMetadata(conn) + if ok := srv.PublicKeyHandler(ctx, key); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } - if srv.PermissionsCallback != nil { - srv.PermissionsCallback(conn.User(), &Permissions{perms}) - } - return perms, nil + ctx.SetValue(ContextKeyPublicKey, key) + return ctx.Permissions().Permissions, nil } } - return config, nil + return config } // Handle sets the Handler for the server. @@ -85,8 +78,7 @@ func (srv *Server) Handle(fn Handler) { // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { defer l.Close() - config, err := srv.makeConfig() - if err != nil { + if err := srv.ensureHostSigner(); err != nil { return err } if srv.Handler == nil { @@ -110,41 +102,46 @@ func (srv *Server) Serve(l net.Listener) error { } return e } - go srv.handleConn(conn, config) + go srv.handleConn(conn) } } -func (srv *Server) handleConn(conn net.Conn, conf *gossh.ServerConfig) { +func (srv *Server) handleConn(conn net.Conn) { defer conn.Close() - sshConn, chans, reqs, err := gossh.NewServerConn(conn, conf) + ctx := newContext(srv) + sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) if err != nil { + // TODO: trigger event callback return } + ctx.applyConnMetadata(sshConn) go gossh.DiscardRequests(reqs) for ch := range chans { if ch.ChannelType() != "session" { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } - go srv.handleChannel(sshConn, ch) + go srv.handleChannel(sshConn, ch, ctx) } } -func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel) { +func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) { ch, reqs, err := newChan.Accept() if err != nil { + // TODO: trigger event callback return } - sess := srv.newSession(conn, ch) + sess := srv.newSession(conn, ch, ctx) sess.handleRequests(reqs) } -func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel) *session { +func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel, ctx *sshContext) *session { sess := &session{ Channel: ch, conn: conn, handler: srv.Handler, ptyCb: srv.PtyCallback, + ctx: ctx, } return sess } diff --git a/session.go b/session.go index 8228ce1..8e88d55 100644 --- a/session.go +++ b/session.go @@ -2,6 +2,7 @@ package ssh import ( "bytes" + "context" "errors" "fmt" "net" @@ -43,6 +44,15 @@ type Session interface { // used it will return nil. PublicKey() PublicKey + // Context returns the connection's context. The returned context is always + // non-nil and holds the same data as the Context passed into auth + // handlers and callbacks. + Context() context.Context + + // Permissions returns a copy of the Permissions object that was available for + // setup in the auth handlers via the Context. + Permissions() Permissions + // Pty returns PTY information, a channel of window size changes, and a boolean // of whether or not a PTY was accepted for this session. Pty() (Pty, <-chan Window, bool) @@ -61,6 +71,7 @@ type session struct { env []string ptyCb PtyCallback cmd []string + ctx *sshContext } func (sess *session) Write(p []byte) (n int, err error) { @@ -80,18 +91,18 @@ func (sess *session) Write(p []byte) (n int, err error) { } func (sess *session) PublicKey() PublicKey { - if sess.conn.Permissions == nil { - return nil - } - s, ok := sess.conn.Permissions.Extensions["_publickey"] - if !ok { - return nil - } - key, err := ParsePublicKey([]byte(s)) - if err != nil { - return nil - } - return key + return sess.ctx.Value(ContextKeyPublicKey).(PublicKey) +} + +func (sess *session) Permissions() Permissions { + // use context permissions because its properly + // wrapped and easier to dereference + perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) + return *perms +} + +func (sess *session) Context() context.Context { + return sess.ctx.Context } func (sess *session) Exit(code int) error { @@ -163,22 +174,25 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { req.Reply(false, nil) continue } + ptyReq, ok := parsePtyRequest(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } if sess.ptyCb != nil { - ok := sess.ptyCb(sess.conn.User(), &Permissions{sess.conn.Permissions}) + ok := sess.ptyCb(sess.ctx, ptyReq) if !ok { req.Reply(false, nil) continue } } - ptyReq, ok := parsePtyRequest(req.Payload) - if ok { - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - close(sess.winch) - }() - } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() req.Reply(ok, nil) case "window-change": if sess.pty == nil { diff --git a/session_test.go b/session_test.go index 83705b1..c68a242 100644 --- a/session_test.go +++ b/session_test.go @@ -11,15 +11,14 @@ import ( ) func (srv *Server) serveOnce(l net.Listener) error { - config, err := srv.makeConfig() - if err != nil { + if err := srv.ensureHostSigner(); err != nil { return err } conn, e := l.Accept() if e != nil { return e } - srv.handleConn(conn, config) + srv.handleConn(conn) return nil } @@ -63,6 +62,7 @@ func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh. } func TestStdout(t *testing.T) { + t.Parallel() testBytes := []byte("Hello world\n") session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { @@ -81,6 +81,7 @@ func TestStdout(t *testing.T) { } func TestStderr(t *testing.T) { + t.Parallel() testBytes := []byte("Hello world\n") session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { @@ -99,6 +100,7 @@ func TestStderr(t *testing.T) { } func TestStdin(t *testing.T) { + t.Parallel() testBytes := []byte("Hello world\n") session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { @@ -118,6 +120,7 @@ func TestStdin(t *testing.T) { } func TestUser(t *testing.T) { + t.Parallel() testUser := []byte("progrium") session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { @@ -138,6 +141,7 @@ func TestUser(t *testing.T) { } func TestDefaultExitStatusZero(t *testing.T) { + t.Parallel() session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { // noop @@ -151,6 +155,7 @@ func TestDefaultExitStatusZero(t *testing.T) { } func TestExplicitExitStatusZero(t *testing.T) { + t.Parallel() session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { s.Exit(0) @@ -164,6 +169,7 @@ func TestExplicitExitStatusZero(t *testing.T) { } func TestExitStatusNonZero(t *testing.T) { + t.Parallel() session, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { s.Exit(1) @@ -181,6 +187,7 @@ func TestExitStatusNonZero(t *testing.T) { } func TestPty(t *testing.T) { + t.Parallel() term := "xterm" winWidth := 40 winHeight := 80 @@ -214,6 +221,7 @@ func TestPty(t *testing.T) { } func TestPtyResize(t *testing.T) { + t.Parallel() winch0 := Window{40, 80} winch1 := Window{80, 160} winch2 := Window{20, 40} diff --git a/ssh.go b/ssh.go index 0f5f43d..f4b5e37 100644 --- a/ssh.go +++ b/ssh.go @@ -34,16 +34,13 @@ type Option func(*Server) error type Handler func(Session) // PublicKeyHandler is a callback for performing public key authentication. -type PublicKeyHandler func(user string, key PublicKey) bool +type PublicKeyHandler func(ctx Context, key PublicKey) bool // PasswordHandler is a callback for performing password authentication. -type PasswordHandler func(user, password string) bool - -// PermissionsCallback is a hook for setting up user permissions. -type PermissionsCallback func(user string, permissions *Permissions) error +type PasswordHandler func(ctx Context, password string) bool // PtyCallback is a hook for allowing PTY sessions. -type PtyCallback func(user string, permissions *Permissions) bool +type PtyCallback func(ctx Context, pty Pty) bool // Window represents the size of a PTY window. type Window struct { diff --git a/wrap.go b/wrap.go index 6da45ee..d1f2b16 100644 --- a/wrap.go +++ b/wrap.go @@ -10,8 +10,7 @@ type PublicKey interface { // The Permissions type holds fine-grained permissions that are specific to a // user or a specific authentication method for a user. Permissions, except for // "source-address", must be enforced in the server application layer, after -// successful authentication. The Permissions are passed on in ServerConn so a -// server implementation can honor them. +// successful authentication. type Permissions struct { *gossh.Permissions }