Skip to content

Commit a17c9e3

Browse files
committed
Merge branch 'master' of github.com:Saavrm26/LocalAI into resume_download
2 parents 7a46b41 + 20edd44 commit a17c9e3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+633
-114
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ DETECT_LIBS?=true
88
# llama.cpp versions
99
GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
1010
GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
11-
CPPLLAMA_VERSION?=9394bbd484f802ce80d2858033583af3ef700d25
11+
CPPLLAMA_VERSION?=53ff6b9b9fb25ed0ec0a213e05534fe7c3d0040f
1212

1313
# whisper.cpp version
1414
WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp

core/http/app.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ func API(application *application.Application) (*fiber.App, error) {
8787

8888
router := fiber.New(fiberCfg)
8989

90+
router.Use(middleware.StripPathPrefix())
91+
9092
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
9193
scheme := "http"
9294
if listenData.TLS {

core/http/app_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,31 @@ func postInvalidRequest(url string) (error, int) {
237237
return nil, resp.StatusCode
238238
}
239239

240+
func getRequest(url string, header http.Header) (error, int, []byte) {
241+
242+
req, err := http.NewRequest("GET", url, nil)
243+
if err != nil {
244+
return err, -1, nil
245+
}
246+
247+
req.Header = header
248+
249+
client := &http.Client{}
250+
resp, err := client.Do(req)
251+
if err != nil {
252+
return err, -1, nil
253+
}
254+
255+
defer resp.Body.Close()
256+
257+
body, err := io.ReadAll(resp.Body)
258+
if err != nil {
259+
return err, -1, nil
260+
}
261+
262+
return nil, resp.StatusCode, body
263+
}
264+
240265
const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml`
241266

242267
//go:embed backend-assets/*
@@ -345,6 +370,33 @@ var _ = Describe("API test", func() {
345370
})
346371
})
347372

373+
Context("URL routing Tests", func() {
374+
It("Should support reverse-proxy when unauthenticated", func() {
375+
376+
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
377+
"X-Forwarded-Proto": {"https"},
378+
"X-Forwarded-Host": {"example.org"},
379+
"X-Forwarded-Prefix": {"/myprefix/"},
380+
})
381+
Expect(err).To(BeNil(), "error")
382+
Expect(sc).To(Equal(401), "status code")
383+
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
384+
})
385+
386+
It("Should support reverse-proxy when authenticated", func() {
387+
388+
err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{
389+
"Authorization": {bearerKey},
390+
"X-Forwarded-Proto": {"https"},
391+
"X-Forwarded-Host": {"example.org"},
392+
"X-Forwarded-Prefix": {"/myprefix/"},
393+
})
394+
Expect(err).To(BeNil(), "error")
395+
Expect(sc).To(Equal(200), "status code")
396+
Expect(string(body)).To(ContainSubstring(`<base href="https://example.org/myprefix/" />`), "body")
397+
})
398+
})
399+
348400
Context("Applying models", func() {
349401

350402
It("applies models from a gallery", func() {

core/http/elements/buttons.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func installButton(galleryName string) elem.Node {
1616
"class": "float-right inline-block rounded bg-primary px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
1717
"hx-swap": "outerHTML",
1818
// post the Model ID as param
19-
"hx-post": "/browse/install/model/" + galleryName,
19+
"hx-post": "browse/install/model/" + galleryName,
2020
},
2121
elem.I(
2222
attrs.Props{
@@ -36,7 +36,7 @@ func reInstallButton(galleryName string) elem.Node {
3636
"hx-target": "#action-div-" + dropBadChars(galleryName),
3737
"hx-swap": "outerHTML",
3838
// post the Model ID as param
39-
"hx-post": "/browse/install/model/" + galleryName,
39+
"hx-post": "browse/install/model/" + galleryName,
4040
},
4141
elem.I(
4242
attrs.Props{
@@ -80,7 +80,7 @@ func deleteButton(galleryID string) elem.Node {
8080
"hx-target": "#action-div-" + dropBadChars(galleryID),
8181
"hx-swap": "outerHTML",
8282
// post the Model ID as param
83-
"hx-post": "/browse/delete/model/" + galleryID,
83+
"hx-post": "browse/delete/model/" + galleryID,
8484
},
8585
elem.I(
8686
attrs.Props{

core/http/elements/gallery.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func searchableElement(text, icon string) elem.Node {
4747
// "value": text,
4848
//"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
4949
"href": "#!",
50-
"hx-post": "/browse/search/models",
50+
"hx-post": "browse/search/models",
5151
"hx-target": "#search-results",
5252
// TODO: this doesn't work
5353
// "hx-vals": `{ \"search\": \"` + text + `\" }`,

core/http/elements/progressbar.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func StartProgressBar(uid, progress, text string) string {
6464
return elem.Div(
6565
attrs.Props{
6666
"hx-trigger": "done",
67-
"hx-get": "/browse/job/" + uid,
67+
"hx-get": "browse/job/" + uid,
6868
"hx-swap": "outerHTML",
6969
"hx-target": "this",
7070
},
@@ -77,7 +77,7 @@ func StartProgressBar(uid, progress, text string) string {
7777
},
7878
elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive
7979
elem.Div(attrs.Props{
80-
"hx-get": "/browse/job/progress/" + uid,
80+
"hx-get": "browse/job/progress/" + uid,
8181
"hx-trigger": "every 600ms",
8282
"hx-target": "this",
8383
"hx-swap": "innerHTML",

core/http/endpoints/explorer/dashboard.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/gofiber/fiber/v2"
88
"github.com/mudler/LocalAI/core/explorer"
9+
"github.com/mudler/LocalAI/core/http/utils"
910
"github.com/mudler/LocalAI/internal"
1011
)
1112

@@ -14,6 +15,7 @@ func Dashboard() func(*fiber.Ctx) error {
1415
summary := fiber.Map{
1516
"Title": "LocalAI API - " + internal.PrintableVersion(),
1617
"Version": internal.PrintableVersion(),
18+
"BaseURL": utils.BaseURL(c),
1719
}
1820

1921
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {

core/http/endpoints/localai/gallery.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/google/uuid"
1010
"github.com/mudler/LocalAI/core/config"
1111
"github.com/mudler/LocalAI/core/gallery"
12+
"github.com/mudler/LocalAI/core/http/utils"
1213
"github.com/mudler/LocalAI/core/schema"
1314
"github.com/mudler/LocalAI/core/services"
1415
"github.com/rs/zerolog/log"
@@ -82,7 +83,8 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
8283
Galleries: mgs.galleries,
8384
ConfigURL: input.ConfigURL,
8485
}
85-
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
86+
87+
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
8688
}
8789
}
8890

@@ -105,7 +107,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
105107
return err
106108
}
107109

108-
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
110+
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
109111
}
110112
}
111113

core/http/endpoints/localai/welcome.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"github.com/gofiber/fiber/v2"
55
"github.com/mudler/LocalAI/core/config"
66
"github.com/mudler/LocalAI/core/gallery"
7+
"github.com/mudler/LocalAI/core/http/utils"
78
"github.com/mudler/LocalAI/core/p2p"
89
"github.com/mudler/LocalAI/core/services"
910
"github.com/mudler/LocalAI/internal"
@@ -32,6 +33,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
3233
summary := fiber.Map{
3334
"Title": "LocalAI API - " + internal.PrintableVersion(),
3435
"Version": internal.PrintableVersion(),
36+
"BaseURL": utils.BaseURL(c),
3537
"Models": modelsWithoutConfig,
3638
"ModelsConfig": backendConfigs,
3739
"GalleryConfig": galleryConfigs,

core/http/explorer.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/gofiber/fiber/v2/middleware/favicon"
88
"github.com/gofiber/fiber/v2/middleware/filesystem"
99
"github.com/mudler/LocalAI/core/explorer"
10+
"github.com/mudler/LocalAI/core/http/middleware"
1011
"github.com/mudler/LocalAI/core/http/routes"
1112
)
1213

@@ -22,6 +23,7 @@ func Explorer(db *explorer.Database) *fiber.App {
2223

2324
app := fiber.New(fiberCfg)
2425

26+
app.Use(middleware.StripPathPrefix())
2527
routes.RegisterExplorerRoutes(app, db)
2628

2729
httpFS := http.FS(embedDirStatic)

core/http/middleware/auth.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/gofiber/fiber/v2"
99
"github.com/gofiber/fiber/v2/middleware/keyauth"
1010
"github.com/mudler/LocalAI/core/config"
11+
"github.com/mudler/LocalAI/core/http/utils"
1112
)
1213

1314
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
@@ -39,7 +40,9 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
3940
if applicationConfig.OpaqueErrors {
4041
return ctx.SendStatus(401)
4142
}
42-
return ctx.Status(401).Render("views/login", nil)
43+
return ctx.Status(401).Render("views/login", fiber.Map{
44+
"BaseURL": utils.BaseURL(ctx),
45+
})
4346
}
4447
if applicationConfig.OpaqueErrors {
4548
return ctx.SendStatus(500)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package middleware
2+
3+
import (
4+
"strings"
5+
6+
"github.com/gofiber/fiber/v2"
7+
)
8+
9+
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
10+
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
11+
func StripPathPrefix() fiber.Handler {
12+
return func(c *fiber.Ctx) error {
13+
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
14+
if prefix != "" {
15+
path := c.Path()
16+
pos := len(prefix)
17+
18+
if prefix[pos-1] == '/' {
19+
pos--
20+
} else {
21+
prefix += "/"
22+
}
23+
24+
if strings.HasPrefix(path, prefix) {
25+
c.Path(path[pos:])
26+
break
27+
} else if prefix[:pos] == path {
28+
c.Redirect(prefix)
29+
return nil
30+
}
31+
}
32+
}
33+
34+
return c.Next()
35+
}
36+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package middleware
2+
3+
import (
4+
"net/http/httptest"
5+
"testing"
6+
7+
"github.com/gofiber/fiber/v2"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestStripPathPrefix(t *testing.T) {
12+
var actualPath string
13+
14+
app := fiber.New()
15+
16+
app.Use(StripPathPrefix())
17+
18+
app.Get("/hello/world", func(c *fiber.Ctx) error {
19+
actualPath = c.Path()
20+
return nil
21+
})
22+
23+
app.Get("/", func(c *fiber.Ctx) error {
24+
actualPath = c.Path()
25+
return nil
26+
})
27+
28+
for _, tc := range []struct {
29+
name string
30+
path string
31+
prefixHeader []string
32+
expectStatus int
33+
expectPath string
34+
}{
35+
{
36+
name: "without prefix and header",
37+
path: "/hello/world",
38+
expectStatus: 200,
39+
expectPath: "/hello/world",
40+
},
41+
{
42+
name: "without prefix and headers on root path",
43+
path: "/",
44+
expectStatus: 200,
45+
expectPath: "/",
46+
},
47+
{
48+
name: "without prefix but header",
49+
path: "/hello/world",
50+
prefixHeader: []string{"/otherprefix/"},
51+
expectStatus: 200,
52+
expectPath: "/hello/world",
53+
},
54+
{
55+
name: "with prefix but non-matching header",
56+
path: "/prefix/hello/world",
57+
prefixHeader: []string{"/otherprefix/"},
58+
expectStatus: 404,
59+
},
60+
{
61+
name: "with prefix and matching header",
62+
path: "/myprefix/hello/world",
63+
prefixHeader: []string{"/myprefix/"},
64+
expectStatus: 200,
65+
expectPath: "/hello/world",
66+
},
67+
{
68+
name: "with prefix and 1st header matching",
69+
path: "/myprefix/hello/world",
70+
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
71+
expectStatus: 200,
72+
expectPath: "/hello/world",
73+
},
74+
{
75+
name: "with prefix and 2nd header matching",
76+
path: "/myprefix/hello/world",
77+
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
78+
expectStatus: 200,
79+
expectPath: "/hello/world",
80+
},
81+
{
82+
name: "with prefix and header not ending with slash",
83+
path: "/myprefix/hello/world",
84+
prefixHeader: []string{"/myprefix"},
85+
expectStatus: 200,
86+
expectPath: "/hello/world",
87+
},
88+
{
89+
name: "with prefix and non-matching header not ending with slash",
90+
path: "/myprefix-suffix/hello/world",
91+
prefixHeader: []string{"/myprefix"},
92+
expectStatus: 404,
93+
},
94+
{
95+
name: "redirect when prefix does not end with a slash",
96+
path: "/myprefix",
97+
prefixHeader: []string{"/myprefix"},
98+
expectStatus: 302,
99+
expectPath: "/myprefix/",
100+
},
101+
} {
102+
t.Run(tc.name, func(t *testing.T) {
103+
actualPath = ""
104+
req := httptest.NewRequest("GET", tc.path, nil)
105+
if tc.prefixHeader != nil {
106+
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
107+
}
108+
109+
resp, err := app.Test(req, -1)
110+
111+
require.NoError(t, err)
112+
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
113+
114+
if tc.expectStatus == 200 {
115+
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
116+
} else if tc.expectStatus == 302 {
117+
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
118+
}
119+
})
120+
}
121+
}

0 commit comments

Comments
 (0)