Skip to content

Commit a8c1ace

Browse files
committed
feat: IP restriction
ie: ''' LOCALAI_IP_ALLOWLIST=192.168.1.0/24,10.0.0.1,127.0.0.1 '''
1 parent e905e90 commit a8c1ace

File tree

6 files changed

+176
-0
lines changed

6 files changed

+176
-0
lines changed

.env

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,7 @@
100100
#
101101
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
102102
# LOCALAI_WATCHDOG_BUSY_TIMEOUT=5m
103+
104+
# allowed access ip config, ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
105+
# export LOCALAI_IP_ALLOWLIST="192.168.1.0/24,10.0.0.1,127.0.0.1"
106+
# LOCALAI_IP_ALLOWLIST=192.168.1.0/24

core/cli/run.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type RunCMD struct {
4646
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"`
4747

4848
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
49+
IpAllowList string `env:"LOCALAI_IP_ALLOWLIST,IP_ALLOWLIST" help:"A list of IP addresses or CIDR ranges to allow access" group:"api"`
4950
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
5051
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
5152
CSRF bool `env:"LOCALAI_CSRF" help:"Enables fiber CSRF middleware" group:"api"`
@@ -127,6 +128,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
127128
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
128129
config.WithLoadToMemory(r.LoadToMemory),
129130
config.WithMachineTag(r.MachineTag),
131+
config.WithIPAllowList(r.IpAllowList),
130132
}
131133

132134
if r.DisableMetricsEndpoint {

core/config/application_config.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"regexp"
77
"time"
88

9+
"github.com/mudler/LocalAI/core/http/utils"
910
"github.com/mudler/LocalAI/pkg/system"
1011
"github.com/mudler/LocalAI/pkg/xsysinfo"
1112
"github.com/rs/zerolog/log"
@@ -63,6 +64,11 @@ type ApplicationConfig struct {
6364
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
6465

6566
MachineTag string
67+
68+
// ie: 192.168.1.0/24,10.0.0.1,127.0.0.1
69+
IpAllowList string
70+
71+
IPAllowListHelper *utils.IPAllowList
6672
}
6773

6874
type AppOption func(*ApplicationConfig)
@@ -128,6 +134,15 @@ func WithP2PToken(s string) AppOption {
128134
}
129135
}
130136

137+
func WithIPAllowList(s string) AppOption {
138+
return func(o *ApplicationConfig) {
139+
log.Info().Msgf("Application IpAllowList($LOCALAI_IP_ALLOWLIST): %s", s)
140+
o.IpAllowList = s
141+
var ipAllowListHelper, _ = utils.NewIPAllowList(s)
142+
o.IPAllowListHelper = ipAllowListHelper
143+
}
144+
}
145+
131146
var EnableWatchDog = func(o *ApplicationConfig) {
132147
o.WatchDog = true
133148
}

core/http/app.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ func API(application *application.Application) (*fiber.App, error) {
128128
router.Use(recover.New())
129129
}
130130

131+
//IP restriction
132+
router.Use(func(c *fiber.Ctx) error {
133+
clientIP := c.IP()
134+
if application.ApplicationConfig().IPAllowListHelper.IsAllowed(clientIP) {
135+
return c.Next()
136+
}
137+
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
138+
"error": "Forbidden: your IP is not allowed",
139+
})
140+
})
141+
131142
if !application.ApplicationConfig().DisableMetrics {
132143
metricsService, err := services.NewLocalAIMetricsService()
133144
if err != nil {

core/http/utils/IPAllowList.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package utils
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"net/netip"
7+
"strings"
8+
"sync"
9+
)
10+
11+
type IPAllowList struct {
12+
allowList string
13+
cidrs []*net.IPNet
14+
ips []net.IP
15+
mu sync.RWMutex
16+
enabled bool
17+
}
18+
19+
func NewIPAllowList(allowList string) (*IPAllowList, error) {
20+
21+
w := &IPAllowList{}
22+
err := w.Update(allowList)
23+
return w, err
24+
}
25+
26+
func (w *IPAllowList) GetAllowList() string {
27+
return w.allowList
28+
}
29+
30+
func (w *IPAllowList) Update(allowListStr string) error {
31+
var cidrs []*net.IPNet
32+
var ips []net.IP
33+
34+
allowList := make([]string, 0)
35+
if allowListStr != "" {
36+
allowList = strings.Split(allowListStr, ",")
37+
}
38+
39+
for _, item := range allowList {
40+
_, cidrNet, err := net.ParseCIDR(item)
41+
if err == nil {
42+
cidrs = append(cidrs, cidrNet)
43+
} else {
44+
ip := net.ParseIP(item)
45+
if ip != nil {
46+
ips = append(ips, ip)
47+
} else {
48+
return fmt.Errorf("invalid allowList item: %s", item)
49+
}
50+
}
51+
}
52+
53+
w.mu.Lock()
54+
defer w.mu.Unlock()
55+
w.allowList = allowListStr
56+
w.cidrs = cidrs
57+
w.ips = ips
58+
w.enabled = len(cidrs) > 0 || len(ips) > 0
59+
return nil
60+
}
61+
62+
func (w *IPAllowList) IsAllowed(ip interface{}) bool {
63+
if !w.enabled {
64+
return true
65+
}
66+
67+
var parsedIP net.IP
68+
switch v := ip.(type) {
69+
case string:
70+
parsedIP = net.ParseIP(v)
71+
case net.IP:
72+
parsedIP = v
73+
case netip.Addr:
74+
parsedIP = net.IP(v.AsSlice())
75+
default:
76+
if str, ok := v.(string); ok {
77+
parsedIP = net.ParseIP(str)
78+
}
79+
}
80+
81+
if parsedIP == nil {
82+
return false
83+
}
84+
85+
w.mu.RLock()
86+
defer w.mu.RUnlock()
87+
88+
for _, cidr := range w.cidrs {
89+
if cidr.Contains(parsedIP) {
90+
return true
91+
}
92+
}
93+
94+
for _, allowedIP := range w.ips {
95+
if parsedIP.Equal(allowedIP) {
96+
return true
97+
}
98+
}
99+
return false
100+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package utils
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
func TestIPAllowList(t *testing.T) {
9+
// Test empty AllowList (no restrictions)
10+
w, err := NewIPAllowList("")
11+
12+
if err != nil {
13+
t.Fatalf("Expected no error for empty AllowList, got: %v", err)
14+
}
15+
if !w.IsAllowed("192.168.1.100") {
16+
t.Error("Empty AllowList should allow all IPs")
17+
}
18+
19+
// Test valid AllowList
20+
AllowList := "192.168.1.0/24,10.0.0.1,127.0.0.1"
21+
w, err = NewIPAllowList(AllowList)
22+
if err != nil {
23+
t.Fatalf("Failed to create IP AllowList: %v", err)
24+
}
25+
26+
tests := []struct {
27+
ip string
28+
expected bool
29+
}{
30+
{"192.168.1.100", true},
31+
{"10.0.0.1", true},
32+
{"127.0.0.1", true},
33+
{"10.0.0.2", false},
34+
{"172.16.0.1", false},
35+
}
36+
37+
for _, tc := range tests {
38+
t.Run(fmt.Sprintf("IP: %s", tc.ip), func(t *testing.T) {
39+
if got := w.IsAllowed(tc.ip); got != tc.expected {
40+
t.Errorf("isAllowedIP(%q) = %v, want %v", tc.ip, got, tc.expected)
41+
}
42+
})
43+
}
44+
}

0 commit comments

Comments
 (0)