diff --git a/docs/index.md b/docs/index.md index 063ab7313..0620155b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2424,6 +2424,33 @@ To bring it all together: there are three ways to instruct a `ghttp` server to h When a `ghttp` server receives a request it first checks against the set of handlers registered via `RouteToHandler` if there is no such handler it proceeds to pop an `AppendHandlers` handler off the stack, if the stack of ordered handlers is empty, it will check whether `GetAllowUnhandledRequests` returns `true` or `false`. If `false` the test fails. If `true`, a response is sent with whatever `GetUnhandledRequestStatusCode` returns. +### Using a RoundTripper to route requests to the test Server + +So far you have seen examples of using `server.URL()` to get the string URL of the test server. This is ok if you are testing code where you can pass the URL. In some cases you might need to pass a `http.Client` or similar. + +You can use `server.RounderTripper(nil)` to create a `http.RounderTripper` which will redirect requests to the test server. + +The method takes another `http.RounderTripper` to make the request to the test server, this allows chaining `http.Transports` or otherwise. + +If passed `nil`, then `http.DefaultTransport` is used to make the request. + +```go +Describe("The http client", func() { + var server *ghttp.Server + var httpClient *http.Client + + BeforeEach(func() { + server = ghttp.NewServer() + httpClient = &http.Client{Transport: server.RounderTripper(nil)} + }) + + AfterEach(func() { + //shut down the server between tests + server.Close() + }) +}) +``` + ## `gbytes`: Testing Streaming Buffers `gbytes` implements `gbytes.Buffer` - an `io.WriteCloser` that captures all input to an in-memory buffer. diff --git a/ghttp/test_server.go b/ghttp/test_server.go index 383573bd9..dfde0a427 100644 --- a/ghttp/test_server.go +++ b/ghttp/test_server.go @@ -186,26 +186,26 @@ type Server struct { calls int } -//Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest). +// Start() starts an unstarted ghttp server. It is a catastrophic error to call Start more than once (thanks, httptest). func (s *Server) Start() { s.HTTPTestServer.Start() } -//URL() returns a url that will hit the server +// URL() returns a url that will hit the server func (s *Server) URL() string { s.rwMutex.RLock() defer s.rwMutex.RUnlock() return s.HTTPTestServer.URL } -//Addr() returns the address on which the server is listening. +// Addr() returns the address on which the server is listening. func (s *Server) Addr() string { s.rwMutex.RLock() defer s.rwMutex.RUnlock() return s.HTTPTestServer.Listener.Addr().String() } -//Close() should be called at the end of each test. It spins down and cleans up the test server. +// Close() should be called at the end of each test. It spins down and cleans up the test server. func (s *Server) Close() { s.rwMutex.Lock() server := s.HTTPTestServer @@ -217,14 +217,14 @@ func (s *Server) Close() { } } -//ServeHTTP() makes Server an http.Handler -//When the server receives a request it handles the request in the following order: +// ServeHTTP() makes Server an http.Handler +// When the server receives a request it handles the request in the following order: // -//1. If the request matches a handler registered with RouteToHandler, that handler is called. -//2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order. -//3. If all registered handlers have been called then: -// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode -// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed. +// 1. If the request matches a handler registered with RouteToHandler, that handler is called. +// 2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order. +// 3. If all registered handlers have been called then: +// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode +// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.rwMutex.Lock() defer func() { @@ -280,7 +280,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } -//ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests) +// ReceivedRequests is an array containing all requests received by the server (both handled and unhandled requests) func (s *Server) ReceivedRequests() []*http.Request { s.rwMutex.RLock() defer s.rwMutex.RUnlock() @@ -288,10 +288,10 @@ func (s *Server) ReceivedRequests() []*http.Request { return s.receivedRequests } -//RouteToHandler can be used to register handlers that will always handle requests that match -//the passed in method and path. +// RouteToHandler can be used to register handlers that will always handle requests that match +// the passed in method and path. // -//The path may be either a string object or a *regexp.Regexp. +// The path may be either a string object or a *regexp.Regexp. func (s *Server) RouteToHandler(method string, path interface{}, handler http.HandlerFunc) { s.rwMutex.Lock() defer s.rwMutex.Unlock() @@ -337,7 +337,7 @@ func (s *Server) handlerForRoute(method string, path string) (http.HandlerFunc, return nil, false } -//AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc... +// AppendHandlers will appends http.HandlerFuncs to the server's list of registered handlers. The first incoming request is handled by the first handler, the second by the second, etc... func (s *Server) AppendHandlers(handlers ...http.HandlerFunc) { s.rwMutex.Lock() defer s.rwMutex.Unlock() @@ -345,9 +345,9 @@ func (s *Server) AppendHandlers(handlers ...http.HandlerFunc) { s.requestHandlers = append(s.requestHandlers, handlers...) } -//SetHandler overrides the registered handler at the passed in index with the passed in handler -//This is useful, for example, when a server has been set up in a shared context, but must be tweaked -//for a particular test. +// SetHandler overrides the registered handler at the passed in index with the passed in handler +// This is useful, for example, when a server has been set up in a shared context, but must be tweaked +// for a particular test. func (s *Server) SetHandler(index int, handler http.HandlerFunc) { s.rwMutex.Lock() defer s.rwMutex.Unlock() @@ -355,7 +355,7 @@ func (s *Server) SetHandler(index int, handler http.HandlerFunc) { s.requestHandlers[index] = handler } -//GetHandler returns the handler registered at the passed in index. +// GetHandler returns the handler registered at the passed in index. func (s *Server) GetHandler(index int) http.HandlerFunc { s.rwMutex.RLock() defer s.rwMutex.RUnlock() @@ -374,12 +374,12 @@ func (s *Server) Reset() { s.routedHandlers = nil } -//WrapHandler combines the passed in handler with the handler registered at the passed in index. -//This is useful, for example, when a server has been set up in a shared context but must be tweaked -//for a particular test. +// WrapHandler combines the passed in handler with the handler registered at the passed in index. +// This is useful, for example, when a server has been set up in a shared context but must be tweaked +// for a particular test. // -//If the currently registered handler is A, and the new passed in handler is B then -//WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index +// If the currently registered handler is A, and the new passed in handler is B then +// WrapHandler will generate a new handler that first calls A, then calls B, and assign it to index func (s *Server) WrapHandler(index int, handler http.HandlerFunc) { existingHandler := s.GetHandler(index) s.SetHandler(index, CombineHandlers(existingHandler, handler)) @@ -392,7 +392,7 @@ func (s *Server) CloseClientConnections() { s.HTTPTestServer.CloseClientConnections() } -//SetAllowUnhandledRequests enables the server to accept unhandled requests. +// SetAllowUnhandledRequests enables the server to accept unhandled requests. func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) { s.rwMutex.Lock() defer s.rwMutex.Unlock() @@ -400,7 +400,7 @@ func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) { s.AllowUnhandledRequests = allowUnhandledRequests } -//GetAllowUnhandledRequests returns true if the server accepts unhandled requests. +// GetAllowUnhandledRequests returns true if the server accepts unhandled requests. func (s *Server) GetAllowUnhandledRequests() bool { s.rwMutex.RLock() defer s.rwMutex.RUnlock() @@ -408,7 +408,7 @@ func (s *Server) GetAllowUnhandledRequests() bool { return s.AllowUnhandledRequests } -//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests +// SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests func (s *Server) SetUnhandledRequestStatusCode(statusCode int) { s.rwMutex.Lock() defer s.rwMutex.Unlock() @@ -416,10 +416,31 @@ func (s *Server) SetUnhandledRequestStatusCode(statusCode int) { s.UnhandledRequestStatusCode = statusCode } -//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests +// GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests func (s *Server) GetUnhandledRequestStatusCode() int { s.rwMutex.RLock() defer s.rwMutex.RUnlock() return s.UnhandledRequestStatusCode } + +// RoundTripper returns a RoundTripper which updates requests to point to the server. +// This is useful when you want to use the server as a RoundTripper in an http.Client. +// If rt is nil, http.DefaultTransport is used. +func (s *Server) RoundTripper(rt http.RoundTripper) http.RoundTripper { + if rt == nil { + rt = http.DefaultTransport + } + return RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + r.URL.Scheme = "http" + r.URL.Host = s.Addr() + return rt.RoundTrip(r) + }) +} + +// Helper type for creating a RoundTripper from a function +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return fn(r) +} diff --git a/ghttp/test_server_test.go b/ghttp/test_server_test.go index abd9678bd..ebd1d779c 100644 --- a/ghttp/test_server_test.go +++ b/ghttp/test_server_test.go @@ -1190,4 +1190,63 @@ var _ = Describe("TestServer", func() { }) }) }) + + Describe("RoundTripper", func() { + var called []string + BeforeEach(func() { + called = []string{} + s.RouteToHandler("GET", "/routed", func(w http.ResponseWriter, req *http.Request) { + called = append(called, "get") + }) + s.RouteToHandler("POST", "/routed", func(w http.ResponseWriter, req *http.Request) { + called = append(called, "post") + }) + }) + + It("should send http traffic to test server with default transport", func() { + client := http.Client{Transport: s.RoundTripper(nil)} + client.Get("http://example.com/routed") + client.Post("http://example.com/routed", "application/json", nil) + client.Get("http://foo.bar/routed") + client.Post("http://foo.bar/routed", "application/json", nil) + Expect(called).Should(Equal([]string{"get", "post", "get", "post"})) + }) + + It("should send https traffic to test server with default transport", func() { + client := http.Client{Transport: s.RoundTripper(nil)} + client.Get("https://example.com/routed") + client.Post("https://example.com/routed", "application/json", nil) + client.Get("https://foo.bar/routed") + client.Post("https://foo.bar/routed", "application/json", nil) + Expect(called).Should(Equal([]string{"get", "post", "get", "post"})) + }) + + It("should send http traffic to test server with default transport", func() { + transport := http.Transport{} + client := http.Client{Transport: s.RoundTripper(&transport)} + client.Get("http://example.com/routed") + client.Post("http://example.com/routed", "application/json", nil) + client.Get("http://foo.bar/routed") + client.Post("http://foo.bar/routed", "application/json", nil) + Expect(called).Should(Equal([]string{"get", "post", "get", "post"})) + }) + + It("should send http traffic to test server with default transport", func() { + transport := http.Transport{} + client := http.Client{Transport: s.RoundTripper(&transport)} + client.Get("https://example.com/routed") + client.Post("https://example.com/routed", "application/json", nil) + client.Get("https://foo.bar/routed") + client.Post("https://foo.bar/routed", "application/json", nil) + Expect(called).Should(Equal([]string{"get", "post", "get", "post"})) + }) + + It("should not change the path of the request", func() { + client := http.Client{Transport: s.RoundTripper(nil)} + client.Get("https://example.com/routed") + Expect(called).Should(Equal([]string{"get"})) + Expect(s.ReceivedRequests()).Should(HaveLen(1)) + Expect(s.ReceivedRequests()[0].URL.Path).Should(Equal("/routed")) + }) + }) })