Skip to content

Commit 4e8766f

Browse files
nsarychevsapessi
authored andcommitted
Add API GW Proxy context and Lambda Context to http.Request context (awslabs#33)
This looks good. Thanks for the contribution. I'll update the readme to document the new `ProxyWithContext` method. * Pass context.Context from Lambda runtime to http.Request. Fixes awslabs#27 * Move API GW context to context.Context * Remove api GW context header * Pass APIGatewayContext in Request.Context * Refactor to way better * Cleanup * Add comments * Add tests * PR comment fixes * Inbtroduce two paths * Refactor adapters * Rename for consistency
1 parent eb5a49d commit 4e8766f

File tree

15 files changed

+272
-42
lines changed

15 files changed

+272
-42
lines changed

chi/adapter.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package chiadapter
55

66
import (
7+
"context"
78
"net/http"
89

910
"github.com/aws/aws-lambda-go/events"
@@ -29,9 +30,21 @@ func New(chi *chi.Mux) *ChiLambda {
2930

3031
// Proxy receives an API Gateway proxy event, transforms it into an http.Request
3132
// object, and sends it to the chi.Mux for routing.
32-
// It returns a proxy response object gneerated from the http.ResponseWriter.
33+
// It returns a proxy response object generated from the http.ResponseWriter.
3334
func (g *ChiLambda) Proxy(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
3435
chiRequest, err := g.ProxyEventToHTTPRequest(req)
36+
return g.proxyInternal(chiRequest, err)
37+
}
38+
39+
// ProxyWithContext receives context and an API Gateway proxy event,
40+
// transforms them into an http.Request object, and sends it to the chi.Mux for routing.
41+
// It returns a proxy response object generated from the http.ResponseWriter.
42+
func (g *ChiLambda) ProxyWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
43+
chiRequest, err := g.EventToRequestWithContext(ctx, req)
44+
return g.proxyInternal(chiRequest, err)
45+
}
46+
47+
func (g *ChiLambda) proxyInternal(chiRequest *http.Request, err error) (events.APIGatewayProxyResponse, error) {
3548

3649
if err != nil {
3750
return core.GatewayTimeout(), core.NewLoggedError("Could not convert proxy event to request: %v", err)

chi/chilambda_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package chiadapter_test
22

33
import (
4+
"context"
45
"log"
56
"net/http"
67

78
"github.com/aws/aws-lambda-go/events"
8-
"github.com/awslabs/aws-lambda-go-api-proxy/chi"
9+
chiadapter "github.com/awslabs/aws-lambda-go-api-proxy/chi"
910
"github.com/go-chi/chi"
1011

1112
. "github.com/onsi/ginkgo"
@@ -29,10 +30,14 @@ var _ = Describe("ChiLambda tests", func() {
2930
HTTPMethod: "GET",
3031
}
3132

32-
resp, err := adapter.Proxy(req)
33+
resp, err := adapter.ProxyWithContext(context.Background(), req)
3334

3435
Expect(err).To(BeNil())
3536
Expect(resp.StatusCode).To(Equal(200))
37+
38+
resp, err = adapter.Proxy(req)
39+
Expect(err).To(BeNil())
40+
Expect(resp.StatusCode).To(Equal(200))
3641
})
3742
})
3843
})

core/request.go

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package core
44

55
import (
66
"bytes"
7+
"context"
78
"encoding/base64"
89
"encoding/json"
910
"errors"
@@ -15,6 +16,7 @@ import (
1516
"strings"
1617

1718
"github.com/aws/aws-lambda-go/events"
19+
"github.com/aws/aws-lambda-go/lambdacontext"
1820
)
1921

2022
// CustomHostVariable is the name of the environment variable that contains
@@ -102,13 +104,33 @@ func (r *RequestAccessor) StripBasePath(basePath string) string {
102104
return newBasePath
103105
}
104106

105-
// ProxyEventToHTTPRequest converts an API Gateway proxy event into an
106-
// http.Request object.
107-
// Returns the populated request with an additional two custom headers for the
108-
// stage variables and API Gateway context. To access these properties use
109-
// the GetAPIGatewayStageVars and GetAPIGatewayContext method of the RequestAccessor
110-
// object.
107+
// ProxyEventToHTTPRequest converts an API Gateway proxy event into a http.Request object.
108+
// Returns the populated http request with additional two custom headers for the stage variables and API Gateway context.
109+
// To access these properties use the GetAPIGatewayStageVars and GetAPIGatewayContext method of the RequestAccessor object.
111110
func (r *RequestAccessor) ProxyEventToHTTPRequest(req events.APIGatewayProxyRequest) (*http.Request, error) {
111+
httpRequest, err := r.EventToRequest(req)
112+
if err != nil {
113+
log.Println(err)
114+
return nil, err
115+
}
116+
return addToHeader(httpRequest, req)
117+
}
118+
119+
// EventToRequestWithContext converts an API Gateway proxy event and context into an http.Request object.
120+
// Returns the populated http request with lambda context, stage variables and APIGatewayProxyRequestContext as part of its context.
121+
// Access those using GetAPIGatewayContextFromContext, GetStageVarsFromContext and GetRuntimeContextFromContext functions in this package.
122+
func (r *RequestAccessor) EventToRequestWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (*http.Request, error) {
123+
httpRequest, err := r.EventToRequest(req)
124+
if err != nil {
125+
log.Println(err)
126+
return nil, err
127+
}
128+
return addToContext(ctx, httpRequest, req), nil
129+
}
130+
131+
// EventToRequest converts an API Gateway proxy event into an http.Request object.
132+
// Returns the populated request maintaining headers
133+
func (r *RequestAccessor) EventToRequest(req events.APIGatewayProxyRequest) (*http.Request, error) {
112134
decodedBody := []byte(req.Body)
113135
if req.IsBase64Encoded {
114136
base64Body, err := base64.StdEncoding.DecodeString(req.Body)
@@ -157,23 +179,57 @@ func (r *RequestAccessor) ProxyEventToHTTPRequest(req events.APIGatewayProxyRequ
157179
log.Println(err)
158180
return nil, err
159181
}
160-
161182
for h := range req.Headers {
162183
httpRequest.Header.Add(h, req.Headers[h])
163184
}
185+
return httpRequest, nil
186+
}
164187

165-
apiGwContext, err := json.Marshal(req.RequestContext)
188+
func addToHeader(req *http.Request, apiGwRequest events.APIGatewayProxyRequest) (*http.Request, error) {
189+
stageVars, err := json.Marshal(apiGwRequest.StageVariables)
166190
if err != nil {
167-
log.Println("Could not Marshal API GW context for custom header")
191+
log.Println("Could not marshal stage variables for custom header")
168192
return nil, err
169193
}
170-
stageVars, err := json.Marshal(req.StageVariables)
194+
req.Header.Add(APIGwStageVarsHeader, string(stageVars))
195+
apiGwContext, err := json.Marshal(apiGwRequest.RequestContext)
171196
if err != nil {
172-
log.Println("Could not marshal stage variables for custom header")
173-
return nil, err
197+
log.Println("Could not Marshal API GW context for custom header")
198+
return req, err
174199
}
175-
httpRequest.Header.Add(APIGwContextHeader, string(apiGwContext))
176-
httpRequest.Header.Add(APIGwStageVarsHeader, string(stageVars))
200+
req.Header.Add(APIGwContextHeader, string(apiGwContext))
201+
return req, nil
202+
}
177203

178-
return httpRequest, nil
204+
func addToContext(ctx context.Context, req *http.Request, apiGwRequest events.APIGatewayProxyRequest) *http.Request {
205+
lc, _ := lambdacontext.FromContext(ctx)
206+
rc := requestContext{lambdaContext: lc, gatewayProxyContext: apiGwRequest.RequestContext, stageVars: apiGwRequest.StageVariables}
207+
ctx = context.WithValue(req.Context(), ctxKey{}, rc)
208+
return req.WithContext(ctx)
209+
}
210+
211+
// GetAPIGatewayContextFromContext retrieve APIGatewayProxyRequestContext from context.Context
212+
func GetAPIGatewayContextFromContext(ctx context.Context) (events.APIGatewayProxyRequestContext, bool) {
213+
v, ok := ctx.Value(ctxKey{}).(requestContext)
214+
return v.gatewayProxyContext, ok
215+
}
216+
217+
// GetRuntimeContextFromContext retrieve Lambda Runtime Context from context.Context
218+
func GetRuntimeContextFromContext(ctx context.Context) (*lambdacontext.LambdaContext, bool) {
219+
v, ok := ctx.Value(ctxKey{}).(requestContext)
220+
return v.lambdaContext, ok
221+
}
222+
223+
// GetStageVarsFromContext retrieve stage variables from context
224+
func GetStageVarsFromContext(ctx context.Context) (map[string]string, bool) {
225+
v, ok := ctx.Value(ctxKey{}).(requestContext)
226+
return v.stageVars, ok
227+
}
228+
229+
type ctxKey struct{}
230+
231+
type requestContext struct {
232+
lambdaContext *lambdacontext.LambdaContext
233+
gatewayProxyContext events.APIGatewayProxyRequestContext
234+
stageVars map[string]string
179235
}

core/request_test.go

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package core_test
22

33
import (
4+
"context"
45
"encoding/base64"
56
"io/ioutil"
67
"math/rand"
78
"os"
89

910
"github.com/aws/aws-lambda-go/events"
11+
"github.com/aws/aws-lambda-go/lambdacontext"
1012
"github.com/awslabs/aws-lambda-go-api-proxy/core"
1113

1214
. "github.com/onsi/ginkgo"
@@ -18,14 +20,15 @@ var _ = Describe("RequestAccessor tests", func() {
1820
accessor := core.RequestAccessor{}
1921
basicRequest := getProxyRequest("/hello", "GET")
2022
It("Correctly converts a basic event", func() {
21-
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
23+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
2224
Expect(err).To(BeNil())
2325
Expect("/hello").To(Equal(httpReq.URL.Path))
2426
Expect("GET").To(Equal(httpReq.Method))
2527
})
2628

2729
basicRequest = getProxyRequest("/hello", "get")
2830
It("Converts method to uppercase", func() {
31+
// calling old method to verify reverse compatibility
2932
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
3033
Expect(err).To(BeNil())
3134
Expect("/hello").To(Equal(httpReq.URL.Path))
@@ -45,7 +48,7 @@ var _ = Describe("RequestAccessor tests", func() {
4548
binaryRequest.IsBase64Encoded = true
4649

4750
It("Decodes a base64 encoded body", func() {
48-
httpReq, err := accessor.ProxyEventToHTTPRequest(binaryRequest)
51+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), binaryRequest)
4952
Expect(err).To(BeNil())
5053
Expect("/hello").To(Equal(httpReq.URL.Path))
5154
Expect("POST").To(Equal(httpReq.Method))
@@ -63,7 +66,7 @@ var _ = Describe("RequestAccessor tests", func() {
6366
"world": {"2", "3"},
6467
}
6568
It("Populates query string correctly", func() {
66-
httpReq, err := accessor.ProxyEventToHTTPRequest(qsRequest)
69+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), qsRequest)
6770
Expect(err).To(BeNil())
6871
Expect("/hello").To(Equal(httpReq.URL.Path))
6972
Expect("GET").To(Equal(httpReq.Method))
@@ -83,7 +86,8 @@ var _ = Describe("RequestAccessor tests", func() {
8386

8487
It("Stips the base path correct", func() {
8588
accessor.StripBasePath("app1")
86-
httpReq, err := accessor.ProxyEventToHTTPRequest(basePathRequest)
89+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basePathRequest)
90+
8791
Expect(err).To(BeNil())
8892
Expect("/orders").To(Equal(httpReq.URL.Path))
8993
})
@@ -92,6 +96,7 @@ var _ = Describe("RequestAccessor tests", func() {
9296
contextRequest.RequestContext = getRequestContext()
9397

9498
It("Populates context header correctly", func() {
99+
// calling old method to verify reverse compatibility
95100
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
96101
Expect(err).To(BeNil())
97102
Expect(2).To(Equal(len(httpReq.Header)))
@@ -123,16 +128,49 @@ var _ = Describe("RequestAccessor tests", func() {
123128
contextRequest.RequestContext = getRequestContext()
124129

125130
accessor := core.RequestAccessor{}
131+
// calling old method to verify reverse compatibility
126132
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
127133
Expect(err).To(BeNil())
128134

129-
context, err := accessor.GetAPIGatewayContext(httpReq)
135+
headerContext, err := accessor.GetAPIGatewayContext(httpReq)
136+
Expect(err).To(BeNil())
137+
Expect(headerContext).ToNot(BeNil())
138+
Expect("x").To(Equal(headerContext.AccountID))
139+
Expect("x").To(Equal(headerContext.RequestID))
140+
Expect("x").To(Equal(headerContext.APIID))
141+
proxyContext, ok := core.GetAPIGatewayContextFromContext(httpReq.Context())
142+
// should fail because using header proxy method
143+
Expect(ok).To(BeFalse())
144+
145+
httpReq, err = accessor.EventToRequestWithContext(context.Background(), contextRequest)
130146
Expect(err).To(BeNil())
131-
Expect(context).ToNot(BeNil())
132-
Expect("x").To(Equal(context.AccountID))
133-
Expect("x").To(Equal(context.RequestID))
134-
Expect("x").To(Equal(context.APIID))
135-
Expect("prod").To(Equal(context.Stage))
147+
proxyContext, ok = core.GetAPIGatewayContextFromContext(httpReq.Context())
148+
Expect(ok).To(BeTrue())
149+
Expect("x").To(Equal(proxyContext.APIID))
150+
Expect("x").To(Equal(proxyContext.RequestID))
151+
Expect("x").To(Equal(proxyContext.APIID))
152+
Expect("prod").To(Equal(proxyContext.Stage))
153+
runtimeContext, ok := core.GetRuntimeContextFromContext(httpReq.Context())
154+
Expect(ok).To(BeTrue())
155+
Expect(runtimeContext).To(BeNil())
156+
157+
lambdaContext := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{AwsRequestID: "abc123"})
158+
httpReq, err = accessor.EventToRequestWithContext(lambdaContext, contextRequest)
159+
Expect(err).To(BeNil())
160+
161+
headerContext, err = accessor.GetAPIGatewayContext(httpReq)
162+
// should fail as new context method doesn't populate headers
163+
Expect(err).ToNot(BeNil())
164+
proxyContext, ok = core.GetAPIGatewayContextFromContext(httpReq.Context())
165+
Expect(ok).To(BeTrue())
166+
Expect("x").To(Equal(proxyContext.APIID))
167+
Expect("x").To(Equal(proxyContext.RequestID))
168+
Expect("x").To(Equal(proxyContext.APIID))
169+
Expect("prod").To(Equal(proxyContext.Stage))
170+
runtimeContext, ok = core.GetRuntimeContextFromContext(httpReq.Context())
171+
Expect(ok).To(BeTrue())
172+
Expect(runtimeContext).ToNot(BeNil())
173+
Expect("abc123").To(Equal(runtimeContext.AwsRequestID))
136174
})
137175

138176
It("Populates stage variables correctly", func() {
@@ -150,9 +188,29 @@ var _ = Describe("RequestAccessor tests", func() {
150188
Expect(stageVars["var2"]).ToNot(BeNil())
151189
Expect("value1").To(Equal(stageVars["var1"]))
152190
Expect("value2").To(Equal(stageVars["var2"]))
191+
192+
stageVars, ok := core.GetStageVarsFromContext(httpReq.Context())
193+
// not present in context
194+
Expect(ok).To(BeFalse())
195+
196+
httpReq, err = accessor.EventToRequestWithContext(context.Background(), varsRequest)
197+
Expect(err).To(BeNil())
198+
199+
stageVars, err = accessor.GetAPIGatewayStageVars(httpReq)
200+
// should not be in headers
201+
Expect(err).ToNot(BeNil())
202+
203+
stageVars, ok = core.GetStageVarsFromContext(httpReq.Context())
204+
Expect(ok).To(BeTrue())
205+
Expect(2).To(Equal(len(stageVars)))
206+
Expect(stageVars["var1"]).ToNot(BeNil())
207+
Expect(stageVars["var2"]).ToNot(BeNil())
208+
Expect("value1").To(Equal(stageVars["var1"]))
209+
Expect("value2").To(Equal(stageVars["var2"]))
153210
})
154211

155212
It("Populates the default hostname correctly", func() {
213+
156214
basicRequest := getProxyRequest("orders", "GET")
157215
accessor := core.RequestAccessor{}
158216
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
@@ -167,7 +225,7 @@ var _ = Describe("RequestAccessor tests", func() {
167225
os.Setenv(core.CustomHostVariable, myCustomHost)
168226
basicRequest := getProxyRequest("orders", "GET")
169227
accessor := core.RequestAccessor{}
170-
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
228+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
171229
Expect(err).To(BeNil())
172230

173231
Expect(myCustomHost).To(Equal("http://" + httpReq.Host))
@@ -180,7 +238,7 @@ var _ = Describe("RequestAccessor tests", func() {
180238
os.Setenv(core.CustomHostVariable, myCustomHost+"/")
181239
basicRequest := getProxyRequest("orders", "GET")
182240
accessor := core.RequestAccessor{}
183-
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
241+
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
184242
Expect(err).To(BeNil())
185243

186244
Expect(myCustomHost).To(Equal("http://" + httpReq.Host))

gin/adapter.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package ginadapter
55

66
import (
7+
"context"
78
"net/http"
89

910
"github.com/aws/aws-lambda-go/events"
@@ -29,16 +30,28 @@ func New(gin *gin.Engine) *GinLambda {
2930

3031
// Proxy receives an API Gateway proxy event, transforms it into an http.Request
3132
// object, and sends it to the gin.Engine for routing.
32-
// It returns a proxy response object gneerated from the http.ResponseWriter.
33+
// It returns a proxy response object generated from the http.ResponseWriter.
3334
func (g *GinLambda) Proxy(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
3435
ginRequest, err := g.ProxyEventToHTTPRequest(req)
36+
return g.proxyInternal(ginRequest, err)
37+
}
38+
39+
// ProxyWithContext receives context and an API Gateway proxy event,
40+
// transforms them into an http.Request object, and sends it to the gin.Engine for routing.
41+
// It returns a proxy response object generated from the http.ResponseWriter.
42+
func (g *GinLambda) ProxyWithContext(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
43+
ginRequest, err := g.EventToRequestWithContext(ctx, req)
44+
return g.proxyInternal(ginRequest, err)
45+
}
46+
47+
func (g *GinLambda) proxyInternal(req *http.Request, err error) (events.APIGatewayProxyResponse, error) {
3548

3649
if err != nil {
3750
return core.GatewayTimeout(), core.NewLoggedError("Could not convert proxy event to request: %v", err)
3851
}
3952

4053
respWriter := core.NewProxyResponseWriter()
41-
g.ginEngine.ServeHTTP(http.ResponseWriter(respWriter), ginRequest)
54+
g.ginEngine.ServeHTTP(http.ResponseWriter(respWriter), req)
4255

4356
proxyResponse, err := respWriter.GetProxyResponse()
4457
if err != nil {

0 commit comments

Comments
 (0)