Skip to content

Commit

Permalink
compiler/internal/codegen: Allow header only auth
Browse files Browse the repository at this point in the history
Fix "unused variable" bug when having only headers or query parameter in auth data type.
  • Loading branch information
ekerfelt authored and eandre committed May 30, 2022
1 parent cfe1a46 commit 8edd9ec
Show file tree
Hide file tree
Showing 7 changed files with 804 additions and 6 deletions.
14 changes: 8 additions & 6 deletions compiler/internal/codegen/codegen_main.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ func (b *Builder) decodeRequest(requestDecoder *gocodegen.MarshallingCodeWrapper
}

func (b *Builder) decodeHeaders(g *Group, pos gotoken.Pos, requestDecoder *gocodegen.MarshallingCodeWrapper, params []*encoding.ParameterEncoding) {
if len(params) == 0 {
return
}
g.Comment("Decode Headers")
g.Id("h").Op(":=").Id("req").Dot("Header")
for _, f := range params {
Expand All @@ -495,6 +498,9 @@ func (b *Builder) decodeHeaders(g *Group, pos gotoken.Pos, requestDecoder *gocod
}

func (b *Builder) decodeQueryString(g *Group, pos gotoken.Pos, requestDecoder *gocodegen.MarshallingCodeWrapper, params []*encoding.ParameterEncoding) {
if len(params) == 0 {
return
}
g.Comment("Decode Query String")
g.Id("qs").Op(":=").Id("req").Dot("URL").Dot("Query").Call()

Expand All @@ -511,14 +517,10 @@ func (b *Builder) decodeQueryString(g *Group, pos gotoken.Pos, requestDecoder *g
func (b *Builder) decodeRequestParameters(g *Group, rpc *est.RPC, requestDecoder *gocodegen.MarshallingCodeWrapper, req *encoding.RequestEncoding) {

// Decode headers
if len(req.HeaderParameters) > 0 {
b.decodeHeaders(g, rpc.Func.Pos(), requestDecoder, req.HeaderParameters)
}
b.decodeHeaders(g, rpc.Func.Pos(), requestDecoder, req.HeaderParameters)

// Decode QueryString
if len(req.QueryParameters) > 0 {
b.decodeQueryString(g, rpc.Func.Pos(), requestDecoder, req.QueryParameters)
}
b.decodeQueryString(g, rpc.Func.Pos(), requestDecoder, req.QueryParameters)

// Decode Body
if len(req.BodyParameters) > 0 {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
// main code
package main

import (
"context"
"encore.app/svc"
auth "encore.dev/beta/auth"
"encore.dev/beta/errs"
"encore.dev/runtime"
"encore.dev/runtime/config"
serde "encore.dev/runtime/serde"
"fmt"
"github.com/json-iterator/go"
"github.com/julienschmidt/httprouter"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
_ "unsafe"
)

var json = jsoniter.Config{
EscapeHTML: false,
IndentionStep: config.JsonIndentStepForResponses(),
SortMapKeys: true,
ValidateJsonRawMessage: true,
}.Froze()

func __encore_svc_Eight(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
ctx := req.Context()
runtime.BeginOperation()
defer runtime.FinishOperation()

var err error
dec := &marshaller{}
// Decode request
if value, err := url.PathUnescape(ps[0].Value); err == nil {
ps[0].Value = value
}
p0 := dec.ToString("bar", ps[0].Value, true)
if value, err := url.PathUnescape(ps[1].Value); err == nil {
ps[1].Value = value
}
p1 := dec.ToString("baz", ps[1].Value, true)
inputs, _ := runtime.SerializeInputs(p0, p1)

params := &svc.FooParams{}
switch m := req.Method; m {
case "POST":
// Decode JSON Body
payload := dec.Body(req.Body)
iter := jsoniter.ParseBytes(json, payload)

for iter.ReadObjectCB(func(_ *jsoniter.Iterator, key string) bool {
switch strings.ToLower(key) {
case "name":
dec.ParseJSON("Name", iter, &params.Name)
default:
_ = iter.SkipAndReturnBytes()
}
return true
}) {
}

default:
panic("HTTP method is not supported")
}
// Add trace info
jsonParams, err := json.Marshal(params)
if err != nil {
errs.HTTPError(w, errs.B().Code(errs.Internal).Msg("internal error").Err())
return
}
inputs = append(inputs, jsonParams)

uid, authData, proceed := __encore_authenticate(w, req, true, "svc", "Eight")
if !proceed {
return
}

err = runtime.BeginRequest(ctx, runtime.RequestData{
AuthData: authData,
Endpoint: "Eight",
EndpointExprIdx: 2,
Inputs: inputs,
Path: req.URL.Path,
PathSegments: ps,
Service: "svc",
Type: runtime.RPCCall,
UID: uid,
})
if err != nil {
errs.HTTPError(w, errs.B().Code(errs.Internal).Msg("internal error").Err())
return
}
if dec.LastError != nil {
err := dec.LastError
runtime.FinishRequest(nil, err)
errs.HTTPError(w, err)
return
}

// Call the endpoint
defer func() {
// Catch handler panic
if e := recover(); e != nil {
err := errs.B().Code(errs.Internal).Msgf("panic handling request: %v", e).Err()
runtime.FinishRequest(nil, err)
errs.HTTPError(w, err)
}
}()
resp, respErr := svc.Eight(req.Context(), p0, p1, params)
if respErr != nil {
respErr = errs.Convert(respErr)
runtime.FinishRequest(nil, respErr)
errs.HTTPError(w, respErr)
return
}

// Serialize the response
respData := []byte("null\n")
if resp != nil {
// Encode JSON body
respData, err = serde.SerializeJSONFunc(json, func(ser *serde.JSONSerializer) {
ser.WriteField("Message", resp.Message, false)
})
if err != nil {
marshalErr := errs.WrapCode(err, errs.Internal, "failed to marshal response")
runtime.FinishRequest(nil, marshalErr)
errs.HTTPError(w, marshalErr)
return
}
respData = append(respData, '\n')
}

// Record tracing data
output := [][]byte{respData}
runtime.FinishRequest(output, nil)

// Write response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write(respData)
}

// loadConfig registers the Encore services.
//go:linkname loadConfig encore.dev/runtime/config.loadConfig
func loadConfig() (*config.Config, error) {
services := []*config.Service{{
Endpoints: []*config.Endpoint{{
Access: config.Auth,
Handler: __encore_svc_Eight,
Methods: []string{"POST"},
Name: "Eight",
Path: "/eight/:bar/:baz",
Raw: false,
}},
Name: "svc",
RelPath: "svc",
}}
static := &config.Static{
AppCommit: config.CommitInfo{
Revision: "",
Uncommitted: false,
},
AuthData: reflect.TypeOf((*svc.AuthData)(nil)),
EncoreCompiler: "test",
Services: services,
TestService: "",
Testing: false,
}
return &config.Config{
Runtime: config.ParseRuntime(getAndClearEnv("ENCORE_RUNTIME_CONFIG")),
Secrets: config.ParseSecrets(getAndClearEnv("ENCORE_APP_SECRETS")),
Static: static,
}, nil
}

func main() {
if err := runtime.ListenAndServe(); err != nil {
runtime.Logger().Fatal().Err(err).Msg("could not listen and serve")
}
}

// getAndClearEnv gets an env variable and unsets it.
func getAndClearEnv(env string) string {
val := os.Getenv(env)
os.Unsetenv(env)
return val
}

type validationDetails struct {
Field string `json:"field"`
Err string `json:"err"`
}

func (validationDetails) ErrDetails() {}

// __encore_authenticate authenticates a request.
// It reports the user id, user data, and whether or not to proceed with the request.
// If requireAuth is false, it reports ("", nil, true) on authentication failure.
func __encore_authenticate(w http.ResponseWriter, req *http.Request, requireAuth bool, svcName, rpcName string) (uid auth.UID, authData interface{}, proceed bool) {
param, err := __encore_resolveAuthParam(req)
if err != nil {
if requireAuth {
runtime.Logger().Info().Str("service", svcName).Str("endpoint", rpcName).Msg("rejecting request due to missing auth")
errs.HTTPError(w, errs.B().Code(errs.Unauthenticated).Msg("invalid auth param").Err())
return "", nil, false
}
return "", nil, true
}

uid, authData, err = __encore_validateToken(req.Context(), param)
if errs.Code(err) == errs.Unauthenticated && !requireAuth {
return "", nil, true
} else if err != nil {
errs.HTTPError(w, err)
return "", nil, false
}
return uid, authData, true
}

// __encore_resolveAuthParam resolves the auth parameters from the http request
// or returns an error if auth params cannot be found
func __encore_resolveAuthParam(req *http.Request) (param *svc.AuthHeaders, err error) {
params := &svc.AuthHeaders{}
dec := &marshaller{}
// Decode Headers
h := req.Header
params.Header1 = h.Get("header1")
params.Header2 = dec.ToInt("header2", h.Get("header2"), false)

if dec.LastError != nil {
return nil, dec.LastError
}
return params, nil
}

// __encore_validateToken validates an auth token.
func __encore_validateToken(ctx context.Context, param *svc.AuthHeaders) (uid auth.UID, authData interface{}, authErr error) {
done := make(chan struct{})
paramStr, err := json.MarshalToString(param)
if err != nil {
return "", nil, err
}
call, err := runtime.BeginAuth(3, paramStr)
if err != nil {
return "", nil, err
}

go func() {
defer close(done)
authErr = call.BeginReq(ctx, runtime.RequestData{
Endpoint: "AuthHandler",
EndpointExprIdx: 3,
Inputs: [][]byte{[]byte(paramStr)},
Service: "svc",
Type: runtime.AuthHandler,
})
if authErr != nil {
return
}
defer func() {
if err2 := recover(); err2 != nil {
authErr = errs.B().Code(errs.Internal).Msgf("auth handler panicked: %v", err2).Err()
call.FinishReq(nil, authErr)
}
}()
uid, authData, authErr = svc.AuthHandler(ctx, param)
serialized, _ := runtime.SerializeInputs(uid, authData)
if authErr != nil {
call.FinishReq(nil, authErr)
} else {
call.FinishReq(serialized, nil)
}
}()
<-done
call.Finish(uid, authErr)
return uid, authData, authErr
}

// marshaller is used to serialize request data into strings and deserialize response data from strings
type marshaller struct {
LastError error // The last error that occurred
}

func (e *marshaller) ToString(field string, s string, required bool) (v string) {
if !required && s == "" {
return
}
return s
}

func (e *marshaller) ToInt(field string, s string, required bool) (v int) {
if !required && s == "" {
return
}
x, err := strconv.ParseInt(s, 10, 64)
e.setErr("invalid parameter", field, err)
return int(x)
}

// setErr sets the last error within the object if one is not already set
func (e *marshaller) setErr(msg, field string, err error) {
if err != nil && e.LastError == nil {
e.LastError = fmt.Errorf("%s: %s: %w", field, msg, err)
}
}

func (d *marshaller) Body(body io.Reader) (payload []byte) {
payload, err := ioutil.ReadAll(body)
if err == nil && len(payload) == 0 {
d.setErr("missing request body", "request_body", fmt.Errorf("missing request body"))
} else if err != nil {
d.setErr("could not parse request body", "request_body", err)
}
return payload
}
func (d *marshaller) ParseJSON(field string, iter *jsoniter.Iterator, dst interface{}) {
iter.ReadVal(dst)
d.setErr("invalid json parameter", field, iter.Error)
}
Loading

0 comments on commit 8edd9ec

Please sign in to comment.