Skip to content

Commit 810b261

Browse files
committed
implement preserving env from host into vm in shell command
Signed-off-by: olalekan odukoya <odukoyaonline@gmail.com>
1 parent 96c7179 commit 810b261

File tree

5 files changed

+372
-14
lines changed

5 files changed

+372
-14
lines changed

cmd/limactl/shell.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/sirupsen/logrus"
2020
"github.com/spf13/cobra"
2121

22+
"github.com/lima-vm/lima/v2/pkg/envutil"
2223
"github.com/lima-vm/lima/v2/pkg/instance"
2324
"github.com/lima-vm/lima/v2/pkg/ioutilx"
2425
"github.com/lima-vm/lima/v2/pkg/limayaml"
@@ -54,6 +55,7 @@ func newShellCommand() *cobra.Command {
5455
shellCmd.Flags().String("shell", "", "Shell interpreter, e.g. /bin/bash")
5556
shellCmd.Flags().String("workdir", "", "Working directory")
5657
shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session")
58+
shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell")
5759
return shellCmd
5860
}
5961

@@ -178,7 +180,25 @@ func shellAction(cmd *cobra.Command, args []string) error {
178180
} else {
179181
shell = shellescape.Quote(shell)
180182
}
181-
script := fmt.Sprintf("%s ; exec %s --login", changeDirCmd, shell)
183+
// Handle environment variable propagation
184+
var envPrefix string
185+
withEnv, err := cmd.Flags().GetBool("preserve-env")
186+
if err != nil {
187+
return err
188+
}
189+
if withEnv {
190+
config := envutil.GetFilterConfig()
191+
filteredEnv := envutil.FilterEnvironment(config)
192+
if len(filteredEnv) > 0 {
193+
envVars := make([]string, len(filteredEnv))
194+
for i, envVar := range filteredEnv {
195+
envVars[i] = shellescape.Quote(envVar)
196+
}
197+
envPrefix = "env " + strings.Join(envVars, " ") + " "
198+
}
199+
}
200+
201+
script := fmt.Sprintf("%s ; exec %s%s --login", changeDirCmd, envPrefix, shell)
182202
if len(args) > 1 {
183203
quotedArgs := make([]string, len(args[1:]))
184204
parsingEnv := true

cmd/nerdctl.lima

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/sh
22
set -eu
3-
exec lima nerdctl "$@"
3+
exec limactl shell --preserve-env default nerdctl "$@"

pkg/envutil/envutil.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"strings"
9+
)
10+
11+
// BuiltinBlocklist contains environment variables that should not be propagated by default.
12+
var BuiltinBlocklist = []string{
13+
"TMPDIR",
14+
"PATH",
15+
"HOME",
16+
"USER",
17+
"LOGNAME",
18+
"SHELL",
19+
"PWD",
20+
"OLDPWD",
21+
"TERM",
22+
"TERMINFO",
23+
"SSH_*",
24+
"DISPLAY",
25+
"XAUTHORITY",
26+
"XDG_*",
27+
"_*", // Variables starting with underscore are typically internal
28+
}
29+
30+
type FilterConfig struct {
31+
BlockList []string
32+
AllowList []string
33+
UseBuiltinBlocklist bool
34+
}
35+
36+
func GetFilterConfig() *FilterConfig {
37+
config := &FilterConfig{
38+
UseBuiltinBlocklist: true,
39+
}
40+
41+
if blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK"); blockEnv != "" {
42+
if strings.HasPrefix(blockEnv, "+") {
43+
config.BlockList = append(BuiltinBlocklist, parseEnvList(blockEnv[1:])...)
44+
} else {
45+
config.BlockList = parseEnvList(blockEnv)
46+
config.UseBuiltinBlocklist = false
47+
}
48+
} else {
49+
config.BlockList = BuiltinBlocklist
50+
}
51+
52+
if allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW"); allowEnv != "" {
53+
config.AllowList = parseEnvList(allowEnv)
54+
}
55+
56+
return config
57+
}
58+
59+
func parseEnvList(envList string) []string {
60+
if envList == "" {
61+
return nil
62+
}
63+
64+
parts := strings.Split(envList, ",")
65+
result := make([]string, 0, len(parts))
66+
for _, part := range parts {
67+
if trimmed := strings.TrimSpace(part); trimmed != "" {
68+
result = append(result, trimmed)
69+
}
70+
}
71+
72+
if len(result) == 0 {
73+
return nil
74+
}
75+
76+
return result
77+
}
78+
79+
func matchesPattern(name, pattern string) bool {
80+
if pattern == name {
81+
return true
82+
}
83+
84+
if strings.HasSuffix(pattern, "*") {
85+
prefix := pattern[:len(pattern)-1]
86+
return strings.HasPrefix(name, prefix)
87+
}
88+
89+
return false
90+
}
91+
92+
func isBlocked(name string, patterns []string) bool {
93+
for _, pattern := range patterns {
94+
if matchesPattern(name, pattern) {
95+
return true
96+
}
97+
}
98+
return false
99+
}
100+
101+
func FilterEnvironment(config *FilterConfig) []string {
102+
env := os.Environ()
103+
var filtered []string
104+
105+
for _, envVar := range env {
106+
parts := strings.SplitN(envVar, "=", 2)
107+
if len(parts) != 2 {
108+
continue
109+
}
110+
111+
name := parts[0]
112+
113+
if len(config.AllowList) > 0 {
114+
if !isBlocked(name, config.AllowList) {
115+
continue
116+
}
117+
filtered = append(filtered, envVar)
118+
continue
119+
}
120+
121+
if isBlocked(name, config.BlockList) {
122+
continue
123+
}
124+
125+
filtered = append(filtered, envVar)
126+
}
127+
128+
return filtered
129+
}
130+
131+
func GetBuiltinBlocklist() []string {
132+
result := make([]string, len(BuiltinBlocklist))
133+
copy(result, BuiltinBlocklist)
134+
return result
135+
}

pkg/envutil/envutil_test.go

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"reflect"
9+
"testing"
10+
)
11+
12+
func TestMatchesPattern(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
pattern string
16+
expected bool
17+
}{
18+
{"PATH", "PATH", true},
19+
{"PATH", "HOME", false},
20+
{"SSH_AUTH_SOCK", "SSH_*", true},
21+
{"SSH_AGENT_PID", "SSH_*", true},
22+
{"HOME", "SSH_*", false},
23+
{"XDG_CONFIG_HOME", "XDG_*", true},
24+
{"_LIMA_TEST", "_*", true},
25+
{"LIMA_HOME", "_*", false},
26+
}
27+
28+
for _, tt := range tests {
29+
t.Run(tt.name+"_matches_"+tt.pattern, func(t *testing.T) {
30+
result := matchesPattern(tt.name, tt.pattern)
31+
if result != tt.expected {
32+
t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.name, tt.pattern, result, tt.expected)
33+
}
34+
})
35+
}
36+
}
37+
38+
func TestIsBlocked(t *testing.T) {
39+
patterns := []string{"PATH", "SSH_*", "XDG_*"}
40+
41+
tests := []struct {
42+
name string
43+
expected bool
44+
}{
45+
{"PATH", true},
46+
{"HOME", false},
47+
{"SSH_AUTH_SOCK", true},
48+
{"XDG_CONFIG_HOME", true},
49+
{"USER", false},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
result := isBlocked(tt.name, patterns)
55+
if result != tt.expected {
56+
t.Errorf("isBlocked(%q, %v) = %v, want %v", tt.name, patterns, result, tt.expected)
57+
}
58+
})
59+
}
60+
}
61+
62+
func TestParseEnvList(t *testing.T) {
63+
tests := []struct {
64+
input string
65+
expected []string
66+
}{
67+
{"", nil},
68+
{"PATH", []string{"PATH"}},
69+
{"PATH,HOME", []string{"PATH", "HOME"}},
70+
{"PATH, HOME , USER", []string{"PATH", "HOME", "USER"}},
71+
{" , , ", nil},
72+
}
73+
74+
for _, tt := range tests {
75+
t.Run(tt.input, func(t *testing.T) {
76+
result := parseEnvList(tt.input)
77+
if !reflect.DeepEqual(result, tt.expected) {
78+
t.Errorf("parseEnvList(%q) = %v, want %v", tt.input, result, tt.expected)
79+
}
80+
})
81+
}
82+
}
83+
84+
func TestGetFilterConfig(t *testing.T) {
85+
originalBlock := os.Getenv("LIMA_SHELLENV_BLOCK")
86+
originalAllow := os.Getenv("LIMA_SHELLENV_ALLOW")
87+
defer func() {
88+
os.Setenv("LIMA_SHELLENV_BLOCK", originalBlock)
89+
os.Setenv("LIMA_SHELLENV_ALLOW", originalAllow)
90+
}()
91+
92+
t.Run("default config", func(t *testing.T) {
93+
os.Unsetenv("LIMA_SHELLENV_BLOCK")
94+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
95+
96+
config := GetFilterConfig()
97+
if !config.UseBuiltinBlocklist {
98+
t.Error("Expected UseBuiltinBlocklist to be true")
99+
}
100+
if !reflect.DeepEqual(config.BlockList, BuiltinBlocklist) {
101+
t.Error("Expected BlockList to equal BuiltinBlocklist")
102+
}
103+
if len(config.AllowList) != 0 {
104+
t.Error("Expected AllowList to be empty")
105+
}
106+
})
107+
108+
t.Run("custom blocklist", func(t *testing.T) {
109+
t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME")
110+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
111+
112+
config := GetFilterConfig()
113+
if config.UseBuiltinBlocklist {
114+
t.Error("Expected UseBuiltinBlocklist to be false")
115+
}
116+
expected := []string{"PATH", "HOME"}
117+
if !reflect.DeepEqual(config.BlockList, expected) {
118+
t.Errorf("Expected BlockList to be %v, got %v", expected, config.BlockList)
119+
}
120+
})
121+
122+
t.Run("additive blocklist", func(t *testing.T) {
123+
t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR")
124+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
125+
126+
config := GetFilterConfig()
127+
if !config.UseBuiltinBlocklist {
128+
t.Error("Expected UseBuiltinBlocklist to be true")
129+
}
130+
expected := make([]string, len(BuiltinBlocklist)+1)
131+
copy(expected, BuiltinBlocklist)
132+
expected[len(BuiltinBlocklist)] = "CUSTOM_VAR"
133+
if !reflect.DeepEqual(config.BlockList, expected) {
134+
t.Errorf("Expected BlockList to include builtin + custom, got %v", config.BlockList)
135+
}
136+
})
137+
138+
t.Run("allowlist", func(t *testing.T) {
139+
os.Unsetenv("LIMA_SHELLENV_BLOCK")
140+
t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR")
141+
142+
config := GetFilterConfig()
143+
expected := []string{"FOO", "BAR"}
144+
if !reflect.DeepEqual(config.AllowList, expected) {
145+
t.Errorf("Expected AllowList to be %v, got %v", expected, config.AllowList)
146+
}
147+
})
148+
}
149+
150+
func TestFilterEnvironment(t *testing.T) {
151+
testEnv := []string{
152+
"PATH=/usr/bin",
153+
"HOME=/home/user",
154+
"USER=testuser",
155+
"FOO=bar",
156+
"SSH_AUTH_SOCK=/tmp/ssh",
157+
"XDG_CONFIG_HOME=/config",
158+
}
159+
160+
originalEnviron := os.Environ
161+
defer func() {
162+
_ = originalEnviron
163+
}()
164+
165+
t.Run("default blocklist", func(_ *testing.T) {
166+
config := &FilterConfig{
167+
BlockList: BuiltinBlocklist,
168+
UseBuiltinBlocklist: true,
169+
}
170+
171+
_ = config
172+
_ = testEnv
173+
})
174+
}
175+
176+
func TestGetBuiltinBlocklist(t *testing.T) {
177+
blocklist := GetBuiltinBlocklist()
178+
179+
if &blocklist[0] == &BuiltinBlocklist[0] {
180+
t.Error("GetBuiltinBlocklist should return a copy, not the original slice")
181+
}
182+
183+
if !reflect.DeepEqual(blocklist, BuiltinBlocklist) {
184+
t.Error("GetBuiltinBlocklist should return the same content as BuiltinBlocklist")
185+
}
186+
187+
expectedItems := []string{"PATH", "HOME", "SSH_*"}
188+
for _, item := range expectedItems {
189+
found := false
190+
for _, blocked := range blocklist {
191+
if blocked == item {
192+
found = true
193+
break
194+
}
195+
}
196+
if !found {
197+
t.Errorf("Expected builtin blocklist to contain %q", item)
198+
}
199+
}
200+
}

0 commit comments

Comments
 (0)