Skip to content

Commit d9e9da9

Browse files
committed
implement preserving env from host into vm in shell command
Signed-off-by: olalekan odukoya <odukoyaonline@gmail.com>
1 parent 96c7179 commit d9e9da9

File tree

5 files changed

+379
-14
lines changed

5 files changed

+379
-14
lines changed

cmd/limactl/shell.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/sirupsen/logrus"
2020
"github.com/spf13/cobra"
2121

22+
"github.com/lima-vm/lima/v2/pkg/envutil"
2223
"github.com/lima-vm/lima/v2/pkg/instance"
2324
"github.com/lima-vm/lima/v2/pkg/ioutilx"
2425
"github.com/lima-vm/lima/v2/pkg/limayaml"
@@ -54,6 +55,7 @@ func newShellCommand() *cobra.Command {
5455
shellCmd.Flags().String("shell", "", "Shell interpreter, e.g. /bin/bash")
5556
shellCmd.Flags().String("workdir", "", "Working directory")
5657
shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session")
58+
shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell")
5759
return shellCmd
5860
}
5961

@@ -178,7 +180,25 @@ func shellAction(cmd *cobra.Command, args []string) error {
178180
} else {
179181
shell = shellescape.Quote(shell)
180182
}
181-
script := fmt.Sprintf("%s ; exec %s --login", changeDirCmd, shell)
183+
// Handle environment variable propagation
184+
var envPrefix string
185+
withEnv, err := cmd.Flags().GetBool("preserve-env")
186+
if err != nil {
187+
return err
188+
}
189+
if withEnv {
190+
config := envutil.GetFilterConfig()
191+
filteredEnv := envutil.FilterEnvironment(config)
192+
if len(filteredEnv) > 0 {
193+
envVars := make([]string, len(filteredEnv))
194+
for i, envVar := range filteredEnv {
195+
envVars[i] = shellescape.Quote(envVar)
196+
}
197+
envPrefix = "env " + strings.Join(envVars, " ") + " "
198+
}
199+
}
200+
201+
script := fmt.Sprintf("%s ; exec %s%s --login", changeDirCmd, envPrefix, shell)
182202
if len(args) > 1 {
183203
quotedArgs := make([]string, len(args[1:]))
184204
parsingEnv := true

cmd/nerdctl.lima

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/sh
22
set -eu
3-
exec lima nerdctl "$@"
3+
exec limactl shell --preserve-env default nerdctl "$@"

pkg/envutil/envutil.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"strings"
9+
)
10+
11+
// BuiltinBlocklist contains environment variables that should not be propagated by default.
12+
var BuiltinBlocklist = []string{
13+
"TMPDIR",
14+
"PATH",
15+
"HOME",
16+
"USER",
17+
"LOGNAME",
18+
"SHELL",
19+
"PWD",
20+
"OLDPWD",
21+
"TERM",
22+
"TERMINFO",
23+
"SSH_*",
24+
"DISPLAY",
25+
"XAUTHORITY",
26+
"XDG_*",
27+
"_*", // Variables starting with underscore are typically internal
28+
}
29+
30+
type FilterConfig struct {
31+
BlockList []string
32+
AllowList []string
33+
UseBuiltinBlocklist bool
34+
}
35+
36+
func GetFilterConfig() *FilterConfig {
37+
config := &FilterConfig{
38+
UseBuiltinBlocklist: true,
39+
}
40+
41+
if blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK"); blockEnv != "" {
42+
if strings.HasPrefix(blockEnv, "+") {
43+
additionalBlocks := parseEnvList(blockEnv[1:])
44+
config.BlockList = make([]string, len(BuiltinBlocklist)+len(additionalBlocks))
45+
copy(config.BlockList, BuiltinBlocklist)
46+
copy(config.BlockList[len(BuiltinBlocklist):], additionalBlocks)
47+
} else {
48+
config.BlockList = parseEnvList(blockEnv)
49+
config.UseBuiltinBlocklist = false
50+
}
51+
} else {
52+
config.BlockList = BuiltinBlocklist
53+
}
54+
55+
if allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW"); allowEnv != "" {
56+
config.AllowList = parseEnvList(allowEnv)
57+
}
58+
59+
return config
60+
}
61+
62+
func parseEnvList(envList string) []string {
63+
if envList == "" {
64+
return nil
65+
}
66+
67+
parts := strings.Split(envList, ",")
68+
result := make([]string, 0, len(parts))
69+
for _, part := range parts {
70+
if trimmed := strings.TrimSpace(part); trimmed != "" {
71+
result = append(result, trimmed)
72+
}
73+
}
74+
75+
if len(result) == 0 {
76+
return nil
77+
}
78+
79+
return result
80+
}
81+
82+
func matchesPattern(name, pattern string) bool {
83+
if pattern == name {
84+
return true
85+
}
86+
87+
if strings.HasSuffix(pattern, "*") {
88+
prefix := pattern[:len(pattern)-1]
89+
return strings.HasPrefix(name, prefix)
90+
}
91+
92+
return false
93+
}
94+
95+
func isBlocked(name string, patterns []string) bool {
96+
for _, pattern := range patterns {
97+
if matchesPattern(name, pattern) {
98+
return true
99+
}
100+
}
101+
return false
102+
}
103+
104+
func FilterEnvironment(config *FilterConfig) []string {
105+
env := os.Environ()
106+
var filtered []string
107+
108+
for _, envVar := range env {
109+
parts := strings.SplitN(envVar, "=", 2)
110+
if len(parts) != 2 {
111+
continue
112+
}
113+
114+
name := parts[0]
115+
116+
if len(config.AllowList) > 0 {
117+
if !isBlocked(name, config.AllowList) {
118+
continue
119+
}
120+
filtered = append(filtered, envVar)
121+
continue
122+
}
123+
124+
if isBlocked(name, config.BlockList) {
125+
continue
126+
}
127+
128+
filtered = append(filtered, envVar)
129+
}
130+
131+
return filtered
132+
}
133+
134+
func GetBuiltinBlocklist() []string {
135+
result := make([]string, len(BuiltinBlocklist))
136+
copy(result, BuiltinBlocklist)
137+
return result
138+
}

pkg/envutil/envutil_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"reflect"
9+
"testing"
10+
)
11+
12+
func TestMatchesPattern(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
pattern string
16+
expected bool
17+
}{
18+
{"PATH", "PATH", true},
19+
{"PATH", "HOME", false},
20+
{"SSH_AUTH_SOCK", "SSH_*", true},
21+
{"SSH_AGENT_PID", "SSH_*", true},
22+
{"HOME", "SSH_*", false},
23+
{"XDG_CONFIG_HOME", "XDG_*", true},
24+
{"_LIMA_TEST", "_*", true},
25+
{"LIMA_HOME", "_*", false},
26+
}
27+
28+
for _, tt := range tests {
29+
t.Run(tt.name+"_matches_"+tt.pattern, func(t *testing.T) {
30+
result := matchesPattern(tt.name, tt.pattern)
31+
if result != tt.expected {
32+
t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.name, tt.pattern, result, tt.expected)
33+
}
34+
})
35+
}
36+
}
37+
38+
func TestIsBlocked(t *testing.T) {
39+
patterns := []string{"PATH", "SSH_*", "XDG_*"}
40+
41+
tests := []struct {
42+
name string
43+
expected bool
44+
}{
45+
{"PATH", true},
46+
{"HOME", false},
47+
{"SSH_AUTH_SOCK", true},
48+
{"XDG_CONFIG_HOME", true},
49+
{"USER", false},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
result := isBlocked(tt.name, patterns)
55+
if result != tt.expected {
56+
t.Errorf("isBlocked(%q, %v) = %v, want %v", tt.name, patterns, result, tt.expected)
57+
}
58+
})
59+
}
60+
}
61+
62+
func TestParseEnvList(t *testing.T) {
63+
tests := []struct {
64+
input string
65+
expected []string
66+
}{
67+
{"", nil},
68+
{"PATH", []string{"PATH"}},
69+
{"PATH,HOME", []string{"PATH", "HOME"}},
70+
{"PATH, HOME , USER", []string{"PATH", "HOME", "USER"}},
71+
{" , , ", nil},
72+
}
73+
74+
for _, tt := range tests {
75+
t.Run(tt.input, func(t *testing.T) {
76+
result := parseEnvList(tt.input)
77+
if !reflect.DeepEqual(result, tt.expected) {
78+
t.Errorf("parseEnvList(%q) = %v, want %v", tt.input, result, tt.expected)
79+
}
80+
})
81+
}
82+
}
83+
84+
func TestGetFilterConfig(t *testing.T) {
85+
originalBlock := os.Getenv("LIMA_SHELLENV_BLOCK")
86+
originalAllow := os.Getenv("LIMA_SHELLENV_ALLOW")
87+
defer func() {
88+
if originalBlock != "" {
89+
t.Setenv("LIMA_SHELLENV_BLOCK", originalBlock)
90+
}
91+
if originalAllow != "" {
92+
t.Setenv("LIMA_SHELLENV_ALLOW", originalAllow)
93+
}
94+
}()
95+
96+
t.Run("default config", func(t *testing.T) {
97+
os.Unsetenv("LIMA_SHELLENV_BLOCK")
98+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
99+
100+
config := GetFilterConfig()
101+
if !config.UseBuiltinBlocklist {
102+
t.Error("Expected UseBuiltinBlocklist to be true")
103+
}
104+
if !reflect.DeepEqual(config.BlockList, BuiltinBlocklist) {
105+
t.Error("Expected BlockList to equal BuiltinBlocklist")
106+
}
107+
if len(config.AllowList) != 0 {
108+
t.Error("Expected AllowList to be empty")
109+
}
110+
})
111+
112+
t.Run("custom blocklist", func(t *testing.T) {
113+
t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME")
114+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
115+
116+
config := GetFilterConfig()
117+
if config.UseBuiltinBlocklist {
118+
t.Error("Expected UseBuiltinBlocklist to be false")
119+
}
120+
expected := []string{"PATH", "HOME"}
121+
if !reflect.DeepEqual(config.BlockList, expected) {
122+
t.Errorf("Expected BlockList to be %v, got %v", expected, config.BlockList)
123+
}
124+
})
125+
126+
t.Run("additive blocklist", func(t *testing.T) {
127+
t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR")
128+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
129+
130+
config := GetFilterConfig()
131+
if !config.UseBuiltinBlocklist {
132+
t.Error("Expected UseBuiltinBlocklist to be true")
133+
}
134+
expected := make([]string, len(BuiltinBlocklist)+1)
135+
copy(expected, BuiltinBlocklist)
136+
expected[len(BuiltinBlocklist)] = "CUSTOM_VAR"
137+
if !reflect.DeepEqual(config.BlockList, expected) {
138+
t.Errorf("Expected BlockList to include builtin + custom, got %v", config.BlockList)
139+
}
140+
})
141+
142+
t.Run("allowlist", func(t *testing.T) {
143+
os.Unsetenv("LIMA_SHELLENV_BLOCK")
144+
t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR")
145+
146+
config := GetFilterConfig()
147+
expected := []string{"FOO", "BAR"}
148+
if !reflect.DeepEqual(config.AllowList, expected) {
149+
t.Errorf("Expected AllowList to be %v, got %v", expected, config.AllowList)
150+
}
151+
})
152+
}
153+
154+
func TestFilterEnvironment(t *testing.T) {
155+
testEnv := []string{
156+
"PATH=/usr/bin",
157+
"HOME=/home/user",
158+
"USER=testuser",
159+
"FOO=bar",
160+
"SSH_AUTH_SOCK=/tmp/ssh",
161+
"XDG_CONFIG_HOME=/config",
162+
}
163+
164+
originalEnviron := os.Environ
165+
defer func() {
166+
_ = originalEnviron
167+
}()
168+
169+
t.Run("default blocklist", func(_ *testing.T) {
170+
config := &FilterConfig{
171+
BlockList: BuiltinBlocklist,
172+
UseBuiltinBlocklist: true,
173+
}
174+
175+
_ = config
176+
_ = testEnv
177+
})
178+
}
179+
180+
func TestGetBuiltinBlocklist(t *testing.T) {
181+
blocklist := GetBuiltinBlocklist()
182+
183+
if &blocklist[0] == &BuiltinBlocklist[0] {
184+
t.Error("GetBuiltinBlocklist should return a copy, not the original slice")
185+
}
186+
187+
if !reflect.DeepEqual(blocklist, BuiltinBlocklist) {
188+
t.Error("GetBuiltinBlocklist should return the same content as BuiltinBlocklist")
189+
}
190+
191+
expectedItems := []string{"PATH", "HOME", "SSH_*"}
192+
for _, item := range expectedItems {
193+
found := false
194+
for _, blocked := range blocklist {
195+
if blocked == item {
196+
found = true
197+
break
198+
}
199+
}
200+
if !found {
201+
t.Errorf("Expected builtin blocklist to contain %q", item)
202+
}
203+
}
204+
}

0 commit comments

Comments
 (0)