diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 19f9acc..99af631 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -name: Build Check +name: build on: [push] jobs: @@ -7,20 +7,19 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.12 - uses: actions/setup-go@v1 - with: - go-version: 1.12 - id: go + - name: Set up Go 1.12 + uses: actions/setup-go@v1 + with: + go-version: 1.12 + id: go - - name: Check out code into the Go module directory - uses: actions/checkout@v1 + - name: Check out code into the Go module directory + uses: actions/checkout@v1 - - name: Get dependencies - run: | - go get -v -t -d ./... - - - name: Build - run: go build -v . + - name: Build and Test + # `test` firsts triggers `build` which triggers `clean` + # which in turns triggers `vet` and `fmt` as a series + # of operations + run: make test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..21b4961 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +GO = GO111MODULE=on go + +fmt: + ${GO} fmt ./... + +vet: fmt + ${GO} vet ./... + +clean: vet + rm -rf ./bin + ${GO} mod tidy + +build: clean + ${GO} build -o ./bin/bypass-cors ./... + +test: build + ${GO} test -v -cover ./... + +run: clean + ${GO} run ./... -p 80 + +.PHONY: run build clean vet fmt diff --git a/go.mod b/go.mod index c40dedb..cc082ab 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/Shivam010/bypass-cors go 1.12 -require github.com/rs/cors v1.6.0 +require github.com/google/go-cmp v0.3.1 diff --git a/go.sum b/go.sum index a031710..a6ddb1d 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/rs/cors v1.6.0 h1:G9tHG9lebljV9mfp9SNPDL36nCDxmo3zTlAf1YgvzmI= -github.com/rs/cors v1.6.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= diff --git a/helper.go b/helper.go index b021f16..ef4ea6d 100644 --- a/helper.go +++ b/helper.go @@ -60,38 +60,53 @@ func (e *Error) StatusCode() int { return e.Code } +const ( + // headers + VaryHeader = "Vary" + OriginHeader = "Origin" + QuoteHeader = "quote" + // Access Control headers + AllowOrigin = "Access-Control-Allow-Origin" + AllowMethods = "Access-Control-Allow-Methods" + AllowHeaders = "Access-Control-Allow-Headers" + AllowCredentials = "Access-Control-Allow-Credentials" + // Access control request headers + RequestMethod = "Access-Control-Request-Method" + RequestHeaders = "Access-Control-Request-Headers" +) + // defaultHeaders handles a general request and add/set corresponding headers func defaultHeaders(w http.ResponseWriter, r *http.Request) { headers := w.Header() - origin := r.Header.Get("Origin") + origin := r.Header.Get(OriginHeader) // Adding Vary header - for http cache - headers.Add("Vary", "Origin") + headers.Add(VaryHeader, OriginHeader) // quote - headers.Set("quote", "Be Happy :)") + headers.Set(QuoteHeader, "Be Happy :)") // Allowing only the requester - can be set to "*" too - headers.Set("Access-Control-Allow-Origin", origin) + headers.Set(AllowOrigin, origin) // Always allowing credentials - just for the sake of proxy request - headers.Set("Access-Control-Allow-Credentials", "true") + headers.Set(AllowCredentials, "true") } // headersForPreflight handles the pre-flight cors request and add/set the // corresponding headers func headersForPreflight(w http.ResponseWriter, r *http.Request) { headers := w.Header() - reqMethod := r.Header.Get("Access-Control-Request-Method") - reqHeaders := r.Header.Get("Access-Control-Request-Headers") + reqMethod := r.Header.Get(RequestMethod) + reqHeaders := r.Header.Get(RequestHeaders) // Vary header - for http cache - headers.Add("Vary", "Access-Control-Request-Method") - headers.Add("Vary", "Access-Control-Request-Headers") + headers.Add(VaryHeader, RequestMethod) + headers.Add(VaryHeader, RequestHeaders) // Allowing the requested method - headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) + headers.Set(AllowMethods, strings.ToUpper(reqMethod)) // Allowing the requested headers - headers.Set("Access-Control-Allow-Headers", reqHeaders) + headers.Set(AllowHeaders, reqHeaders) } // addHeaders handles request and set headers accordingly. It returns true if @@ -100,7 +115,7 @@ func addHeaders(w http.ResponseWriter, r *http.Request) bool { defaultHeaders(w, r) - if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + if r.Method == http.MethodOptions && r.Header.Get(RequestMethod) != "" { headersForPreflight(w, r) Return(w, &ValuerStruct{Code: http.StatusOK}) return true @@ -109,7 +124,7 @@ func addHeaders(w http.ResponseWriter, r *http.Request) bool { return false } -// getRequestURL returns the reuested URL to bypass-cors +// getRequestURL returns the requested URL to bypass-cors func getRequestURL(w http.ResponseWriter, r *http.Request) *url.URL { if r.URL.Path == "" || r.URL.Path == "/" { diff --git a/init.go b/init.go new file mode 100644 index 0000000..e6550bb --- /dev/null +++ b/init.go @@ -0,0 +1,24 @@ +package main + +import ( + "flag" +) + +const ( + shortHand = " (short hand)" + + defaultPort = "8080" + usagePort = "PORT at which the server will run" +) + +var ( + // PORT at which the server will run (default: 8080), + // can be modified using flags: + // `-port 80` or `-p 80` + PORT string +) + +func init() { + flag.StringVar(&PORT, "port", defaultPort, usagePort) + flag.StringVar(&PORT, "p", defaultPort, usagePort+shortHand) +} diff --git a/main.go b/main.go index 5cf043a..5348c91 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "log" "net/http" - "os" ) type handler struct{} @@ -60,6 +59,11 @@ func (*handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Println("UserClient --> bypass-cors -->", req.URL.Host) + // Populate the rest of the header + for k, v := range r.Header { + req.Header.Add(k, v[0]) + } + res, err := http.DefaultClient.Do(req) if err != nil { fmt.Println("Request Failed:", err) @@ -97,14 +101,13 @@ func (*handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func main() { - var PORT string - if PORT = os.Getenv("PORT"); PORT == "" { - flag.StringVar(&PORT, "p", "8080", "PORT at which the server will run") - } + + // parse all flags set in `init` flag.Parse() - fmt.Printf("\nRunning Proxy ByPass Cors Server at port = %v...\n\n", PORT) + fmt.Printf("\nStarting Proxy ByPass-Cors Server at port(:%v)...\n\n", PORT) + if err := http.ListenAndServe(":"+PORT, &handler{}); err != nil { - log.Println("\n\nPanic", err) + log.Println("\n\nPanics", err) } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..b7fc07b --- /dev/null +++ b/main_test.go @@ -0,0 +1,206 @@ +package main + +import ( + "bytes" + "fmt" + "github.com/google/go-cmp/cmp" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +// args - required arguments for any request to serve +type args struct { + w *httptest.ResponseRecorder + r *http.Request + // a dummy server + srv *http.Server +} + +// resChecker - response checker for checking test response +type resChecker struct { + code int + body string + // set noBodyCheck to true if do not want to check body + noBodyCheck bool + // list header keys which should be present in response + headers []string +} + +// defineTest - defines a single unit test +type defineTest struct { + name string + args *args + resChr *resChecker + setup func(*args, *resChecker) +} + +func changeStdOut(s string) *os.File { + tmp := os.Stdout + l, _ := os.Create("./bin/logs_" + s) + os.Stdout = l + return tmp +} + +func resetStdOut(tmp *os.File) { + os.Stdout = tmp +} + +// NOTE: Test_RootRequest should be changed after the error for Root +// request is successfully replaced with the documentation or landing +// page. Follow Issue: Shivam010/bypass-cors#3 +func Test_RootRequest(t *testing.T) { + tmp := changeStdOut(t.Name()) + defer resetStdOut(tmp) + test := defineTest{ + name: "Root Request", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + }, + resChr: &resChecker{}, + setup: func(ar *args, rc *resChecker) { + ar.r, _ = http.NewRequest("GET", "/", &bytes.Buffer{}) + rc.code = http.StatusPreconditionFailed + rc.body = `{"error":{"Code":412,"Message":"URL not provided","Detail":{"method":"GET","requestedURL":"/"}}}` + "\n" + }, + } + t.Run(test.name, func(t *testing.T) { + ha := &handler{} + test.setup(test.args, test.resChr) + ha.ServeHTTP(test.args.w, test.args.r) + if test.args.w.Code != test.resChr.code { + t.Fatalf("Status code mismatched got: %v, want: %v", test.args.w.Code, test.resChr.code) + } + if !test.resChr.noBodyCheck && !cmp.Equal(test.args.w.Body.String(), test.resChr.body) { + t.Fatalf("Body mismatched got: %s, want: %s", test.args.w.Body.String(), test.resChr.body) + } + }) +} + +func Test_Success(t *testing.T) { + tmp := changeStdOut(t.Name()) + defer resetStdOut(tmp) + tests := []defineTest{ + { + name: "GET-Request", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + srv: &http.Server{Addr: ":8181", Handler: http.NotFoundHandler()}, + }, + resChr: &resChecker{ + headers: []string{ + VaryHeader, QuoteHeader, + AllowOrigin, AllowCredentials, + }, + }, + setup: func(ar *args, rc *resChecker) { + ar.srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "Success") + }) + go ar.srv.ListenAndServe() + ar.r, _ = http.NewRequest("GET", "/localhost"+ar.srv.Addr, &bytes.Buffer{}) + rc.code = http.StatusOK + rc.body = fmt.Sprintln("Success") + }, + }, + // TODO: add test for pre-flight request + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + tt.setup(tt.args, tt.resChr) + defer tt.args.srv.Shutdown(nil) + + ha := &handler{} + ha.ServeHTTP(tt.args.w, tt.args.r) + + if tt.args.w.Code != tt.resChr.code { + t.Fatalf("Status code mismatched got: %v, want: %v", tt.args.w.Code, tt.resChr.code) + } + if !tt.resChr.noBodyCheck && !cmp.Equal(tt.args.w.Body.String(), tt.resChr.body) { + t.Fatalf("Body mismatched got: %s, want: %s", tt.args.w.Body.String(), tt.resChr.body) + } + }) + } +} + +func Test_OtherRequests(t *testing.T) { + tmp := changeStdOut(t.Name()) + defer resetStdOut(tmp) + tests := []defineTest{ + { + name: "Can not Process", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + }, + resChr: &resChecker{}, + setup: func(ar *args, rc *resChecker) { + ar.r, _ = http.NewRequest("GET", "/invalid-request", &bytes.Buffer{}) + rc.code = http.StatusUnprocessableEntity + // Note: the error message `Get http://invalid-request: dial tcp: lookup invalid-request: no such host` + // varies from environment to environment and hence, omitting the check + rc.noBodyCheck = true + //rc.body = `{"error":{"Code":422,"Message":"Get http://invalid-request: dial tcp: lookup invalid-request: no such host","Detail":{"body":"","method":"GET","requestedURL":"http://invalid-request","response":null}}}` + "\n" + }, + }, + { + name: "Invalid Request", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + }, + resChr: &resChecker{}, + setup: func(ar *args, rc *resChecker) { + ar.r, _ = http.NewRequest("GET", "", &bytes.Buffer{}) + ar.r.URL.Path = "%invalid%" + rc.code = http.StatusPreconditionFailed + rc.body = `{"error":{"Code":412,"Message":"parse http://invalid%: invalid URL escape \"%\"","Detail":{"method":"GET","requestedURL":"http://invalid%"}}}` + "\n" + }, + }, + { + name: "URL not Provided", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + }, + resChr: &resChecker{}, + setup: func(ar *args, rc *resChecker) { + ar.r, _ = http.NewRequest("GET", "", &bytes.Buffer{}) + rc.code = http.StatusPreconditionFailed + rc.body = `{"error":{"Code":412,"Message":"URL not provided","Detail":{"method":"GET","requestedURL":""}}}` + "\n" + }, + }, + { + name: "Invalid Method", + args: &args{ + w: httptest.NewRecorder(), + r: nil, + }, + resChr: &resChecker{}, + setup: func(ar *args, rc *resChecker) { + ar.r, _ = http.NewRequest("GET", "/localhost", &bytes.Buffer{}) + ar.r.Method += "/" + rc.code = http.StatusPreconditionFailed + rc.body = `{"error":{"Code":412,"Message":"net/http: invalid method \"GET/\"","Detail":{"body":"","method":"GET/","requestedURL":"http://localhost"}}}` + "\n" + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ha := &handler{} + tt.setup(tt.args, tt.resChr) + ha.ServeHTTP(tt.args.w, tt.args.r) + if tt.args.w.Code != tt.resChr.code { + t.Fatalf("Status code mismatched got: %v, want: %v", tt.args.w.Code, tt.resChr.code) + } + if !tt.resChr.noBodyCheck && !cmp.Equal(tt.args.w.Body.String(), tt.resChr.body) { + t.Fatalf("Body mismatched got: %s, want: %s", tt.args.w.Body.String(), tt.resChr.body) + } + }) + } +}