Skip to content

Commit

Permalink
Refactoring server so it can be embedded.
Browse files Browse the repository at this point in the history
Also adding an account signing request option that causes self signed
account jwt to be forwarded to signing via a nats service.

Signed-off-by: Matthias Hanel <mh@synadia.com>
  • Loading branch information
matthiashanel committed Jun 26, 2020
1 parent 8fd8d24 commit 575934d
Show file tree
Hide file tree
Showing 19 changed files with 928 additions and 638 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ build: fmt check compile

fmt:
#misspell -locale US .
gofmt -s -w main.go
gofmt -s -w server/conf/*.go
gofmt -s -w server/core/*.go
gofmt -s -w server/store/*.go
Expand Down Expand Up @@ -42,4 +43,4 @@ fasttest:
scripts/cov.sh

failfast:
go test -tags test -race -failfast ./...
go test -tags test -race -failfast ./...
42 changes: 13 additions & 29 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ func main() {
flags.Creds = expandPath(flags.Creds)
flags.Directory = expandPath(flags.Directory)

server = core.NewAccountServer()
if err := server.InitializeFromFlags(flags); err != nil {
if server.Logger() != nil {
logStopExit := func(server *core.AccountServer, err error) {
if _, ok := server.Logger().(*core.NilLogger); !ok {
server.Logger().Errorf("%s", err.Error())
} else {
log.Printf("%s", err.Error())
Expand All @@ -77,6 +76,11 @@ func main() {
os.Exit(1)
}

server = core.NewAccountServer()
if err := server.InitializeFromFlags(flags); err != nil {
logStopExit(server, err)
}

go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGHUP)
Expand All @@ -85,52 +89,32 @@ func main() {
signal := <-sigChan

if signal == os.Interrupt {
if server.Logger() != nil {
if _, ok := server.Logger().(*core.NilLogger); !ok {
fmt.Println() // clear the line for the control-C
server.Logger().Noticef("received sig-interrupt, shutting down")
}
server.Logger().Noticef("received sig-interrupt, shutting down")
server.Stop()
os.Exit(0)
}

if signal == syscall.SIGHUP {
if server.Logger() != nil {
server.Logger().Errorf("received sig-hup, restarting")
}
server.Logger().Errorf("received sig-hup, restarting")
server.Stop()
server := core.NewAccountServer()

if err := server.InitializeFromFlags(flags); err != nil {
if server.Logger() != nil {
server.Logger().Errorf("%s", err.Error())
} else {
log.Printf("%s", err.Error())
}
server.Stop()
os.Exit(1)
logStopExit(server, err)
}

if err := server.Start(); err != nil {
if server.Logger() != nil {
server.Logger().Errorf("%s", err.Error())
} else {
log.Printf("%s", err.Error())
}
server.Stop()
os.Exit(1)
logStopExit(server, err)
}
}
}
}()

if err := core.Run(server); err != nil {
if server.Logger() != nil {
server.Logger().Errorf("%s", err.Error())
} else {
log.Printf("%s", err.Error())
}
server.Stop()
os.Exit(1)
logStopExit(server, err)
}

// exit main but keep running goroutines
Expand Down
3 changes: 3 additions & 0 deletions server/conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type AccountServerConfig struct {

OperatorJWTPath string
SystemAccountJWTPath string
SignRequestSubject string
SignRequestTimeout int //milliseconds

// Below options are only to copy jwt from an old account server for initialization
Primary string
Expand Down Expand Up @@ -111,5 +113,6 @@ func DefaultServerConfig() *AccountServerConfig {
}, // in memory store
ReplicationTimeout: 5000,
MaxReplicationPack: 10000,
SignRequestTimeout: 1000,
}
}
226 changes: 11 additions & 215 deletions server/core/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,10 @@
package core

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"time"

"github.com/julienschmidt/httprouter"
"github.com/nats-io/jwt/v2"
)

// http headers
Expand All @@ -41,24 +33,18 @@ const (
)

// JWTHelp handles get requests for JWT help
func (server *AccountServer) JWTHelp(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
server.logger.Tracef("%s: %s", r.RemoteAddr, r.URL.String())
func (h *JwtHandler) JWTHelp(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
h.logger.Tracef("%s: %s", r.RemoteAddr, r.URL.String())
w.Header().Add(ContentType, TextPlain)
w.WriteHeader(http.StatusOK)
w.Write([]byte(jwtAPIHelp))
}

// HealthZ returns a status OK
func (server *AccountServer) HealthZ(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
server.logger.Tracef("%s: %s", r.RemoteAddr, r.URL.String())
w.WriteHeader(http.StatusOK)
}

// GetOperatorJWT returns the known operator JWT
func (server *AccountServer) GetOperatorJWT(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
server.logger.Tracef("%s: %s", r.RemoteAddr, r.URL.String())
func (h *JwtHandler) GetOperatorJWT(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
h.logger.Tracef("%s: %s", r.RemoteAddr, r.URL.String())

if server.operatorJWT == "" {
if h.operatorJWT == "" {
w.WriteHeader(http.StatusBadRequest)
return
}
Expand All @@ -67,218 +53,28 @@ func (server *AccountServer) GetOperatorJWT(w http.ResponseWriter, r *http.Reque
text := strings.ToLower(r.URL.Query().Get("text")) == "true"

if text {
server.writeJWTAsText(w, "", server.operatorJWT)
h.writeJWTAsText(w, h.operatorSubject, h.operatorJWT)
return
}

if decode {
server.writeDecodedJWT(w, "", server.operatorJWT)
h.writeDecodedJWT(w, h.operatorSubject, h.operatorJWT)
return
}

w.Header().Add(ContentType, ApplicationJWT)
w.WriteHeader(http.StatusOK)
w.Write([]byte(server.operatorJWT))
}

func (server *AccountServer) sendErrorResponse(httpStatus int, msg string, account string, err error, w http.ResponseWriter) error {
account = ShortKey(account)
if err != nil {
if account != "" {
server.logger.Errorf("%s - %s - %s", account, msg, err.Error())
} else {
server.logger.Errorf("%s - %s", msg, err.Error())
}
} else {
if account != "" {
server.logger.Errorf("%s - %s", account, msg)
} else {
server.logger.Errorf("%s", msg)
}
}

w.Header().Set(ContentType, TextPlain)
w.WriteHeader(httpStatus)
fmt.Fprintln(w, msg)
return err
w.Write([]byte(h.operatorJWT))
}

func (server *AccountServer) writeJWTAsText(w http.ResponseWriter, pubKey string, theJWT string) {
func (h *JwtHandler) writeJWTAsText(w http.ResponseWriter, pubKey string, theJWT string) {
w.Header().Add(ContentType, TextPlain)
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(theJWT))

if err != nil {
server.logger.Errorf("error writing JWT as text for %s - %s", ShortKey(pubKey), err.Error())
h.logger.Errorf("error writing JWT as text for %s - %s", ShortKey(pubKey), err.Error())
} else {
server.logger.Tracef("returning JWT as text for - %s", ShortKey(pubKey))
}
}

// UnescapedIndentedMarshal handle indention for decoded JWTs
func UnescapedIndentedMarshal(v interface{}, prefix, indent string) ([]byte, error) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
enc.SetIndent(prefix, indent)

err := enc.Encode(v)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func (server *AccountServer) writeDecodedJWT(w http.ResponseWriter, pubKey string, theJWT string) {

claim, err := jwt.DecodeGeneric(theJWT)
if err != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error decoding claim", pubKey, err, w)
return
}

parts := strings.Split(theJWT, ".")
head := parts[0]
sig := parts[2]
headerString, err := base64.RawURLEncoding.DecodeString(head)
if err != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error decoding claim header", pubKey, err, w)
return
}
header := jwt.Header{}
if err := json.Unmarshal(headerString, &header); err != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error unmarshaling claim header", pubKey, err, w)
return
}

headerJSON, err := UnescapedIndentedMarshal(header, "", " ")
if err != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error marshaling claim header", pubKey, err, w)
return
h.logger.Tracef("returning JWT as text for - %s", ShortKey(pubKey))
}

claimJSON, err := UnescapedIndentedMarshal(claim, "", " ")
if err != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error marshaling claim", pubKey, err, w)
return
}

var subErr error

r := regexp.MustCompile(`"token":.*?"(.*?)",`)
claimJSON = r.ReplaceAllFunc(claimJSON, func(m []byte) []byte {
if subErr != nil {
return []byte(fmt.Sprintf(`"token": <bad token - %s>,`, subErr.Error()))
}

tokenStr := string(m)

tokenStr = tokenStr[0 : len(tokenStr)-2] // strip the ",
index := strings.LastIndex(tokenStr, "\"")
tokenStr = tokenStr[index+1:]

activateToken, subErr := jwt.DecodeActivationClaims(tokenStr)

if subErr == nil {
token, subErr := UnescapedIndentedMarshal(activateToken, " ", " ")

tokenStr = string(token)
tokenStr = strings.TrimSpace(tokenStr) // get rid of leading whitespace

if subErr == nil {
decoded := fmt.Sprintf(`"token": %s,`, tokenStr)
return []byte(decoded)
}
}

return []byte(fmt.Sprintf(`"token": <bad token - %s>,`, subErr.Error()))
})

if subErr != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error marshaling tokens", pubKey, subErr, w)
return
}

r = regexp.MustCompile(`"iat":.*?(\d?),`)
claimJSON = r.ReplaceAllFunc(claimJSON, func(m []byte) []byte {
if subErr != nil {
return []byte(fmt.Sprintf(`"iat": <parse error - %s>,`, subErr.Error()))
}

var iat int
iatStr := string(m)
iatStr = iatStr[0 : len(iatStr)-1] // strip the ,
index := strings.LastIndex(iatStr, " ")
iatStr = iatStr[index+1:]
iat, subErr = strconv.Atoi(iatStr)

if subErr != nil {
return []byte(fmt.Sprintf(`"iat": <parse error - %s>,`, subErr.Error()))
}

formatted := UnixToDate(int64(iat))
decoded := fmt.Sprintf(`"iat": %s (%s),`, iatStr, formatted)

return []byte(decoded)
})

r = regexp.MustCompile(`"exp":.*?(\d?),`)
claimJSON = r.ReplaceAllFunc(claimJSON, func(m []byte) []byte {
if subErr != nil {
return []byte(fmt.Sprintf(`"exp": <parse error - %s>,`, subErr.Error()))
}

var iat int
iatStr := string(m)
iatStr = iatStr[0 : len(iatStr)-1] // strip the ,
index := strings.LastIndex(iatStr, " ")
iatStr = iatStr[index+1:]
iat, subErr = strconv.Atoi(iatStr)

if subErr != nil {
return []byte(fmt.Sprintf(`"exp": <parse error - %s>,`, subErr.Error()))
}

formatted := UnixToDate(int64(iat))
decoded := fmt.Sprintf(`"exp": %s (%s),`, iatStr, formatted)

return []byte(decoded)
})

if subErr != nil {
server.sendErrorResponse(http.StatusInternalServerError, "error marshaling tokens", pubKey, subErr, w)
return
}

newLineBytes := []byte("\r\n")
jsonBuff := []byte{}
jsonBuff = append(jsonBuff, headerJSON...)
jsonBuff = append(jsonBuff, newLineBytes...)
jsonBuff = append(jsonBuff, claimJSON...)
jsonBuff = append(jsonBuff, newLineBytes...)
jsonBuff = append(jsonBuff, []byte(sig)...)
// if this last new line is not set curls will show a '%' in the output.
jsonBuff = append(jsonBuff, '\n')

w.Header().Add(ContentType, TextPlain)
w.WriteHeader(http.StatusOK)
_, err = w.Write(jsonBuff)

if err != nil {
server.logger.Errorf("error writing decoded JWT as text for %s - %s", ShortKey(pubKey), err.Error())
} else {
server.logger.Tracef("returning decoded JWT as text for - %s", ShortKey(pubKey))
}
}

func (server *AccountServer) cacheControlForExpiration(pubKey string, expires int64) string {
now := time.Now().UTC()
maxAge := int64(time.Unix(expires, 0).Sub(now).Seconds())
stale := int64(60 * 60) // One hour
return fmt.Sprintf("max-age=%d, stale-while-revalidate=%d, stale-if-error=%d", maxAge, stale, stale)
}

func (server *AccountServer) loadJWT(pubKey string, path string) (string, error) {
server.logger.Noticef("%s:%s", pubKey, path)
return server.jwtStore.Load(pubKey)
}
Loading

0 comments on commit 575934d

Please sign in to comment.