Skip to content

Commit

Permalink
fix: Copy only cookies in proxied requests (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas authored and shipperizer committed Jul 11, 2023
1 parent 89af73b commit 2d1fa6a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 77 deletions.
32 changes: 13 additions & 19 deletions pkg/kratos/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ func (a *API) handleCreateFlow(w http.ResponseWriter, r *http.Request) {
// to avoid this bug.
session, _, _ := a.service.CheckSession(context.Background(), r.Cookies())
if session != nil {
redirectTo, headers, err := a.service.AcceptLoginRequest(context.Background(), session.Identity.Id, q.Get("login_challenge"))
redirectTo, cookies, err := a.service.AcceptLoginRequest(context.Background(), session.Identity.Id, q.Get("login_challenge"))
if err != nil {
a.logger.Errorf("Error when accepting login request: %v\n", err)
http.Error(w, "Failed to accept login request", http.StatusInternalServerError)
return
}
writeHeaders(w, headers)
setCookies(w, cookies)
resp, err := redirectTo.MarshalJSON()
if err != nil {
a.logger.Errorf("Error when marshalling Json: %v\n", err)
Expand Down Expand Up @@ -72,7 +72,7 @@ func (a *API) handleCreateFlow(w http.ResponseWriter, r *http.Request) {
// We redirect the user back to this endpoint with the login_challenge, after they log in, to bypass
// Kratos bug where the user is not redirected to hydra the first time they log in.
// Relevant issue https://github.com/ory/kratos/issues/3052
flow, headers, err := a.service.CreateBrowserLoginFlow(context.Background(), q.Get("aal"), returnTo, q.Get("login_challenge"), refresh, r.Cookies())
flow, cookies, err := a.service.CreateBrowserLoginFlow(context.Background(), q.Get("aal"), returnTo, q.Get("login_challenge"), refresh, r.Cookies())
if err != nil {
// TODO: Add more context
http.Error(w, "Failed to create login flow", http.StatusInternalServerError)
Expand All @@ -85,7 +85,7 @@ func (a *API) handleCreateFlow(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to marshall json", http.StatusInternalServerError)
return
}
writeHeaders(w, headers)
setCookies(w, cookies)
w.WriteHeader(200)
w.Write(resp)
}
Expand All @@ -94,7 +94,7 @@ func (a *API) handleCreateFlow(w http.ResponseWriter, r *http.Request) {
func (a *API) handleGetLoginFlow(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()

flow, headers, err := a.service.GetLoginFlow(context.Background(), q.Get("id"), r.Cookies())
flow, cookies, err := a.service.GetLoginFlow(context.Background(), q.Get("id"), r.Cookies())
if err != nil {
a.logger.Errorf("Error when getting login flow: %v\n", err)
http.Error(w, "Failed to get login flow", http.StatusInternalServerError)
Expand All @@ -107,7 +107,7 @@ func (a *API) handleGetLoginFlow(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to parse login flow", http.StatusInternalServerError)
return
}
writeHeaders(w, headers)
setCookies(w, cookies)
w.WriteHeader(200)
w.Write(resp)
}
Expand All @@ -123,7 +123,7 @@ func (a *API) handleUpdateFlow(w http.ResponseWriter, r *http.Request) {
return
}

flow, headers, err := a.service.UpdateOIDCLoginFlow(context.Background(), q.Get("flow"), *body, r.Cookies())
flow, cookies, err := a.service.UpdateOIDCLoginFlow(context.Background(), q.Get("flow"), *body, r.Cookies())
if err != nil {
a.logger.Errorf("Error when updating login flow: %v\n", err)
http.Error(w, "Failed to update login flow", http.StatusInternalServerError)
Expand All @@ -136,7 +136,7 @@ func (a *API) handleUpdateFlow(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to parse login flow", http.StatusInternalServerError)
return
}
writeHeaders(w, headers)
setCookies(w, cookies)
w.WriteHeader(422)
w.Write(resp)
}
Expand All @@ -146,7 +146,7 @@ func (a *API) handleKratosError(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
id := q.Get("id")

flowError, headers, err := a.service.GetFlowError(context.Background(), id)
flowError, cookies, err := a.service.GetFlowError(context.Background(), id)
if err != nil {
a.logger.Errorf("Error when getting flow error: %v\n", err)
http.Error(w, "Failed to get flow error", http.StatusInternalServerError)
Expand All @@ -159,7 +159,7 @@ func (a *API) handleKratosError(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to parse flow error", http.StatusInternalServerError)
return
}
writeHeaders(w, headers)
setCookies(w, cookies)
w.WriteHeader(200)
w.Write(resp)
}
Expand All @@ -175,15 +175,9 @@ func NewAPI(service ServiceInterface, baseURL string, logger logging.LoggerInter
return a
}

func writeHeaders(w http.ResponseWriter, headers http.Header) {
excludedHeaders := []string{"Content-Length", "Date"}
for k, vs := range headers {
for _, v := range vs {
w.Header().Set(k, v)
}
}
for _, h := range excludedHeaders {
w.Header().Del(h)
func setCookies(w http.ResponseWriter, cookies []*http.Cookie) {
for _, c := range cookies {
http.SetCookie(w, c)
}
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/kratos/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestHandleCreateFlowWithoutSession(t *testing.T) {
req.URL.RawQuery = values.Encode()

mockService.EXPECT().CheckSession(gomock.Any(), req.Cookies()).Return(nil, nil, nil)
mockService.EXPECT().CreateBrowserLoginFlow(gomock.Any(), gomock.Any(), returnTo, loginChallenge, gomock.Any(), req.Cookies()).Return(flow, req.Header, nil)
mockService.EXPECT().CreateBrowserLoginFlow(gomock.Any(), gomock.Any(), returnTo, loginChallenge, gomock.Any(), req.Cookies()).Return(flow, req.Cookies(), nil)

w := httptest.NewRecorder()
mux := chi.NewMux()
Expand Down Expand Up @@ -130,7 +130,7 @@ func TestHandleCreateFlowWithSession(t *testing.T) {
req.URL.RawQuery = values.Encode()

mockService.EXPECT().CheckSession(gomock.Any(), req.Cookies()).Return(session, nil, nil)
mockService.EXPECT().AcceptLoginRequest(gomock.Any(), "test", loginChallenge).Return(redirectTo, req.Header, nil)
mockService.EXPECT().AcceptLoginRequest(gomock.Any(), "test", loginChallenge).Return(redirectTo, req.Cookies(), nil)

w := httptest.NewRecorder()
mux := chi.NewMux()
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestHandleGetLoginFlow(t *testing.T) {
values.Add("id", id)
req.URL.RawQuery = values.Encode()

mockService.EXPECT().GetLoginFlow(gomock.Any(), id, req.Cookies()).Return(flow, req.Header, nil)
mockService.EXPECT().GetLoginFlow(gomock.Any(), id, req.Cookies()).Return(flow, req.Cookies(), nil)

w := httptest.NewRecorder()
mux := chi.NewMux()
Expand Down Expand Up @@ -284,7 +284,7 @@ func TestHandleUpdateFlow(t *testing.T) {
req.URL.RawQuery = values.Encode()

mockService.EXPECT().ParseLoginFlowMethodBody(gomock.Any()).Return(flowBody, nil)
mockService.EXPECT().UpdateOIDCLoginFlow(gomock.Any(), flowId, *flowBody, req.Cookies()).Return(flow, req.Header, nil)
mockService.EXPECT().UpdateOIDCLoginFlow(gomock.Any(), flowId, *flowBody, req.Cookies()).Return(flow, req.Cookies(), nil)

w := httptest.NewRecorder()
mux := chi.NewMux()
Expand Down
12 changes: 6 additions & 6 deletions pkg/kratos/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ type HydraClientInterface interface {
}

type ServiceInterface interface {
CheckSession(context.Context, []*http.Cookie) (*kClient.Session, http.Header, error)
AcceptLoginRequest(context.Context, string, string) (*hClient.OAuth2RedirectTo, http.Header, error)
CreateBrowserLoginFlow(context.Context, string, string, string, bool, []*http.Cookie) (*kClient.LoginFlow, http.Header, error)
GetLoginFlow(context.Context, string, []*http.Cookie) (*kClient.LoginFlow, http.Header, error)
UpdateOIDCLoginFlow(context.Context, string, kClient.UpdateLoginFlowBody, []*http.Cookie) (*ErrorBrowserLocationChangeRequired, http.Header, error)
GetFlowError(context.Context, string) (*kClient.FlowError, http.Header, error)
CheckSession(context.Context, []*http.Cookie) (*kClient.Session, []*http.Cookie, error)
AcceptLoginRequest(context.Context, string, string) (*hClient.OAuth2RedirectTo, []*http.Cookie, error)
CreateBrowserLoginFlow(context.Context, string, string, string, bool, []*http.Cookie) (*kClient.LoginFlow, []*http.Cookie, error)
GetLoginFlow(context.Context, string, []*http.Cookie) (*kClient.LoginFlow, []*http.Cookie, error)
UpdateOIDCLoginFlow(context.Context, string, kClient.UpdateLoginFlowBody, []*http.Cookie) (*ErrorBrowserLocationChangeRequired, []*http.Cookie, error)
GetFlowError(context.Context, string) (*kClient.FlowError, []*http.Cookie, error)
ParseLoginFlowMethodBody(*http.Request) (*kratos_client.UpdateLoginFlowBody, error)
}
24 changes: 12 additions & 12 deletions pkg/kratos/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type ErrorBrowserLocationChangeRequired struct {
RedirectBrowserTo *string `json:"redirect_browser_to,omitempty"`
}

func (s *Service) CheckSession(ctx context.Context, cookies []*http.Cookie) (*kClient.Session, http.Header, error) {
func (s *Service) CheckSession(ctx context.Context, cookies []*http.Cookie) (*kClient.Session, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "kratos.FrontendApi.ToSession")
defer span.End()

Expand All @@ -46,10 +46,10 @@ func (s *Service) CheckSession(ctx context.Context, cookies []*http.Cookie) (*kC

return nil, nil, err
}
return session, resp.Header, nil
return session, resp.Cookies(), nil
}

func (s *Service) AcceptLoginRequest(ctx context.Context, identityID string, lc string) (*hClient.OAuth2RedirectTo, http.Header, error) {
func (s *Service) AcceptLoginRequest(ctx context.Context, identityID string, lc string) (*hClient.OAuth2RedirectTo, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "hydra.OAuth2Api.AcceptOAuth2LoginRequest")
defer span.End()

Expand All @@ -66,12 +66,12 @@ func (s *Service) AcceptLoginRequest(ctx context.Context, identityID string, lc
return nil, nil, err
}

return redirectTo, resp.Header, nil
return redirectTo, resp.Cookies(), nil
}

func (s *Service) CreateBrowserLoginFlow(
ctx context.Context, aal, returnTo, loginChallenge string, refresh bool, cookies []*http.Cookie,
) (*kClient.LoginFlow, http.Header, error) {
) (*kClient.LoginFlow, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "kratos.FrontendApi.CreateBrowserLoginFlow")
defer span.End()

Expand All @@ -88,10 +88,10 @@ func (s *Service) CreateBrowserLoginFlow(
return nil, nil, err
}

return flow, resp.Header, nil
return flow, resp.Cookies(), nil
}

func (s *Service) GetLoginFlow(ctx context.Context, id string, cookies []*http.Cookie) (*kClient.LoginFlow, http.Header, error) {
func (s *Service) GetLoginFlow(ctx context.Context, id string, cookies []*http.Cookie) (*kClient.LoginFlow, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "kratos.FrontendApi.GetLoginFlow")
defer span.End()

Expand All @@ -105,12 +105,12 @@ func (s *Service) GetLoginFlow(ctx context.Context, id string, cookies []*http.C
return nil, nil, err
}

return flow, resp.Header, nil
return flow, resp.Cookies(), nil
}

func (s *Service) UpdateOIDCLoginFlow(
ctx context.Context, flow string, body kClient.UpdateLoginFlowBody, cookies []*http.Cookie,
) (*ErrorBrowserLocationChangeRequired, http.Header, error) {
) (*ErrorBrowserLocationChangeRequired, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "kratos.FrontendApi.UpdateLoginFlow")
defer span.End()

Expand All @@ -131,10 +131,10 @@ func (s *Service) UpdateOIDCLoginFlow(
s.logger.Debugf("Failed to unmarshal JSON: %s", err)
return nil, nil, err
}
return redirectResp, resp.Header, nil
return redirectResp, resp.Cookies(), nil
}

func (s *Service) GetFlowError(ctx context.Context, id string) (*kClient.FlowError, http.Header, error) {
func (s *Service) GetFlowError(ctx context.Context, id string) (*kClient.FlowError, []*http.Cookie, error) {
_, span := s.tracer.Start(ctx, "kratos.FrontendApi.GetFlowError")
defer span.End()

Expand All @@ -144,7 +144,7 @@ func (s *Service) GetFlowError(ctx context.Context, id string) (*kClient.FlowErr
return nil, nil, err
}

return flowError, resp.Header, nil
return flowError, resp.Cookies(), nil
}

func (s *Service) ParseLoginFlowMethodBody(r *http.Request) (*kClient.UpdateLoginFlowBody, error) {
Expand Down
Loading

0 comments on commit 2d1fa6a

Please sign in to comment.