diff --git a/internal/http/server.go b/internal/http/server.go index 8b2b238da7266..c7aeed289cc52 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -24,6 +24,7 @@ import ( "os" "runtime" "strconv" + "strings" "time" "go.uber.org/zap" @@ -101,6 +102,8 @@ func registerDefaults() { Path: StaticPath, Handler: GetStaticHandler(), }) + + RegisterWebUIHandler() } func RegisterStopComponent(triggerComponentStop func(role string) error) { @@ -141,12 +144,78 @@ func RegisterCheckComponentReady(checkActive func(role string) error) { w.Write([]byte(`{"msg": "OK"}`)) }, }) - Register(&Handler{ - Path: RouteWebUI, - Handler: http.FileServer(http.FS(staticFiles)), +} + +func RegisterWebUIHandler() { + httpFS := http.FS(staticFiles) + fileServer := http.FileServer(httpFS) + serveIndex := serveFile(RouteWebUI+"index.html", httpFS) + http.Handle(RouteWebUI, handleNotFound(fileServer, serveIndex)) +} + +type responseInterceptor struct { + http.ResponseWriter + is404 bool +} + +func (ri *responseInterceptor) WriteHeader(status int) { + if status == http.StatusNotFound { + ri.is404 = true + return + } + ri.ResponseWriter.WriteHeader(status) +} + +func (ri *responseInterceptor) Write(p []byte) (int, error) { + if ri.is404 { + return len(p), nil // Pretend the data was written for a 404 + } + return ri.ResponseWriter.Write(p) +} + +// handleNotFound attempts to serve a fallback handler (on404) if the main handler returns a 404 status. +func handleNotFound(handler, on404 http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ri := &responseInterceptor{ResponseWriter: w} + handler.ServeHTTP(ri, r) + + if ri.is404 { + on404.ServeHTTP(w, r) + } }) } +// serveFile serves the specified file content (like "index.html") for HTML requests. +func serveFile(filename string, fs http.FileSystem) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !acceptsHTML(r) { + http.NotFound(w, r) + return + } + + file, err := fs.Open(filename) + if err != nil { + http.NotFound(w, r) + return + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + http.NotFound(w, r) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + http.ServeContent(w, r, fi.Name(), fi.ModTime(), file) + } +} + +// acceptsHTML checks if the request header specifies that HTML is acceptable. +func acceptsHTML(r *http.Request) bool { + return strings.Contains(r.Header.Get("Accept"), "text/html") +} + func Register(h *Handler) { if metricsServer == nil { if paramtable.Get().HTTPCfg.EnablePprof.GetAsBool() { diff --git a/internal/http/server_test.go b/internal/http/server_test.go index d243bf8ac9abf..726a0adb6de0a 100644 --- a/internal/http/server_test.go +++ b/internal/http/server_test.go @@ -24,11 +24,13 @@ import ( "io" "net" "net/http" + "net/http/httptest" "os" "strings" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -239,3 +241,96 @@ func (m *MockIndicator) Health(ctx context.Context) commonpb.StateCode { func (m *MockIndicator) GetName() string { return m.name } + +func TestRegisterWebUIHandler(t *testing.T) { + // Initialize the HTTP server + func() { + defer func() { + if err := recover(); err != nil { + fmt.Println("May the handler has been registered!", err) + } + }() + RegisterWebUIHandler() + }() + + // Create a test server + ts := httptest.NewServer(http.DefaultServeMux) + defer ts.Close() + + // Test cases + tests := []struct { + url string + expectedCode int + expectedBody string + }{ + {"/webui/", http.StatusOK, ""}, + {"/webui/index.html", http.StatusOK, ""}, + {"/webui/unknown", http.StatusOK, ""}, + } + + for _, tt := range tests { + t.Run(tt.url, func(t *testing.T) { + req, err := http.NewRequest("GET", ts.URL+tt.url, nil) + assert.NoError(t, err) + req.Header.Set("Accept", "text/html") + resp, err := ts.Client().Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tt.expectedCode, resp.StatusCode) + + body := make([]byte, len(tt.expectedBody)) + _, err = resp.Body.Read(body) + assert.NoError(t, err) + assert.Contains(t, strings.ToLower(string(body)), tt.expectedBody) + }) + } +} + +func TestHandleNotFound(t *testing.T) { + mainHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }) + fallbackHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Fallback")) + }) + + handler := handleNotFound(mainHandler, fallbackHandler) + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + resp := w.Result() + body := make([]byte, 8) + resp.Body.Read(body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "Fallback", string(body)) +} + +func TestServeFile(t *testing.T) { + fs := http.FS(staticFiles) + handler := serveFile("unknown", fs) + + // No Accept in http header + { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + resp := w.Result() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + + // unknown request file + { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept", "text/html") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + resp := w.Result() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } +}