Skip to content

Commit 642c0ee

Browse files
michaeldwan and Copilot authored
Environment variables in cog.yaml (#2274)
* wip support for build+runtime environment in cog.yaml * support fast boot, cleaning * remove unused code * add env var integration test * fix config for integration test fixture * cog_binary fixture for configuring binary used in integration tests * remove unused Environment field on Build * updated denylist, more validation, more tests * tidy mod files after merge oddness * omit empty Environment config when writing yaml * typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Michael Dwan <m@dwan.io> * correct environment entry description --------- Signed-off-by: Michael Dwan <m@dwan.io> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 7b0d5f1 commit 642c0ee

File tree

14 files changed

+367
-7
lines changed

14 files changed

+367
-7
lines changed

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ require (
2828
github.com/vincent-petithory/dataurl v1.0.0
2929
github.com/xeipuuv/gojsonschema v1.2.0
3030
github.com/xeonx/timeago v1.0.0-rc5
31-
golang.org/x/crypto v0.37.0
3231
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
3332
golang.org/x/sync v0.13.0
3433
golang.org/x/sys v0.32.0
@@ -271,6 +270,7 @@ require (
271270
go.uber.org/automaxprocs v1.6.0 // indirect
272271
go.uber.org/multierr v1.11.0 // indirect
273272
go.uber.org/zap v1.27.0 // indirect
273+
golang.org/x/crypto v0.37.0 // indirect
274274
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
275275
golang.org/x/mod v0.24.0 // indirect
276276
golang.org/x/net v0.39.0 // indirect
@@ -280,7 +280,7 @@ require (
280280
google.golang.org/grpc v1.71.0 // indirect
281281
google.golang.org/protobuf v1.36.6 // indirect
282282
gopkg.in/yaml.v3 v3.0.1 // indirect
283-
gotest.tools/gotestsum v1.12.1 // indirect
283+
gotest.tools/gotestsum v1.12.2 // indirect
284284
honnef.co/go/tools v0.6.1 // indirect
285285
mvdan.cc/gofumpt v0.7.0 // indirect
286286
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
764764
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
765765
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
766766
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
767-
gotest.tools/gotestsum v1.12.1 h1:dvcxFBTFR1QsQmrCQa4k/vDXow9altdYz4CjdW+XeBE=
768-
gotest.tools/gotestsum v1.12.1/go.mod h1:mwDmLbx9DIvr09dnAoGgQPLaSXszNpXpWo2bsQge5BE=
767+
gotest.tools/gotestsum v1.12.2 h1:eli4tu9Q2D/ogDsEGSr8XfQfl7mT0JsGOG6DFtUiZ/Q=
768+
gotest.tools/gotestsum v1.12.2/go.mod h1:kjRtCglPZVsSU0hFHX3M5VWBM6Y63emHuB14ER1/sow=
769769
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
770770
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
771771
honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI=

pkg/config/config.go

+21
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ type Config struct {
7575
Predict string `json:"predict,omitempty" yaml:"predict"`
7676
Train string `json:"train,omitempty" yaml:"train,omitempty"`
7777
Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
78+
Environment []string `json:"environment,omitempty" yaml:"environment,omitempty"`
79+
80+
parsedEnvironment map[string]string
7881
}
7982

8083
func DefaultConfig() *Config {
@@ -319,6 +322,11 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
319322
}
320323
}
321324

325+
// parse and validate environment variables
326+
if err := c.loadEnvironment(); err != nil {
327+
errs = append(errs, err)
328+
}
329+
322330
if len(errs) > 0 {
323331
return errors.Join(errs...)
324332
}
@@ -577,3 +585,16 @@ func sliceContains(slice []string, s string) bool {
577585
}
578586
return false
579587
}
588+
589+
func (c *Config) ParsedEnvironment() map[string]string {
590+
return c.parsedEnvironment
591+
}
592+
593+
func (c *Config) loadEnvironment() error {
594+
env, err := parseAndValidateEnvironment(c.Environment)
595+
if err != nil {
596+
return err
597+
}
598+
c.parsedEnvironment = env
599+
return nil
600+
}

pkg/config/data/config_schema_v1.0.json

+14
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,20 @@
207207
"type": "object",
208208
"additionalItems": true
209209
}
210+
},
211+
"environment": {
212+
"$id": "#/properties/properties/environment",
213+
"type": [
214+
"array",
215+
"null"
216+
],
217+
"description": "A list of environment variables to make available during builds and at runtime, in the format `NAME=value`",
218+
"additionalItems": true,
219+
"items": {
220+
"$id": "#/properties/properties/environment/items",
221+
"type": "string",
222+
"pattern": "^[A-Za-z_][A-Za-z0-9_]*=[^\\s]+$"
223+
}
210224
}
211225
},
212226
"additionalProperties": false

pkg/config/env_variables.go

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package config
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
// EnvironmentVariableDenyList holds the environment variable names that users may
// not set because they are used internally during build or at runtime. An entry
// ending in "*" is a prefix pattern matching every name that starts with the text
// before the asterisk; any other entry must match the name exactly.
//
// There are ways around this restriction, but they are likely to cause unexpected
// behavior and hard-to-debug issues, so Cog's predict-build-push happy path does
// not allow these names to be set.
//
// This list may change at any time. For more context, see:
// https://github.com/replicate/cog/pull/2274/#issuecomment-2831823185
var EnvironmentVariableDenyList = []string{
	// paths
	"PATH",
	"LD_LIBRARY_PATH",
	"PYTHONPATH",
	"VIRTUAL_ENV",
	"PYTHONUNBUFFERED",

	// Replicate
	"R8_*",
	"REPLICATE_*",

	// Nvidia
	"LIBRARY_PATH",
	"CUDA_*",
	"NVIDIA_*",
	"NV_*",

	// pget
	"PGET_*",
	"HF_ENDPOINT",
	"HF_HUB_ENABLE_HF_TRANSFER",

	// k8s
	"KUBERNETES_*",
}
37+
38+
// validateEnvName checks if the given environment variable name is allowed.
39+
// Returns an error if the name matches any of the restricted patterns.
40+
func validateEnvName(name string) error {
41+
for _, pattern := range EnvironmentVariableDenyList {
42+
// Check for exact match
43+
if pattern == name {
44+
return fmt.Errorf("environment variable %q is not allowed", name)
45+
}
46+
47+
// Check for wildcard pattern
48+
if strings.HasSuffix(pattern, "*") {
49+
if strings.HasPrefix(name, pattern[:len(pattern)-1]) {
50+
return fmt.Errorf("environment variable %q is not allowed", name)
51+
}
52+
}
53+
}
54+
return nil
55+
}
56+
57+
// parseAndValidateEnvironment converts a slice of strings in the format of KEY=VALUE
58+
// to a map[string]string. An error is returned if the format is incorrect or if either
59+
// the variable name or value are invalid.
60+
func parseAndValidateEnvironment(input []string) (map[string]string, error) {
61+
env := map[string]string{}
62+
for _, input := range input {
63+
parts := strings.SplitN(input, "=", 2)
64+
if len(parts) != 2 || parts[0] == "" {
65+
return nil, fmt.Errorf("environment variable %q is not in the KEY=VALUE format", input)
66+
}
67+
if err := validateEnvName(parts[0]); err != nil {
68+
return nil, err
69+
}
70+
if _, ok := env[parts[0]]; ok {
71+
return nil, fmt.Errorf("environment variable %q is already defined", parts[0])
72+
}
73+
env[parts[0]] = parts[1]
74+
}
75+
return env, nil
76+
}

pkg/config/env_variables_test.go

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package config
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestEnvironmentConfig(t *testing.T) {
12+
t.Run("ParsingValidInput", func(t *testing.T) {
13+
cases := []struct {
14+
Name string
15+
Input []string
16+
Expected map[string]string
17+
}{
18+
{
19+
Name: "ValidInput",
20+
Input: []string{"NAME=VALUE"},
21+
Expected: map[string]string{"NAME": "VALUE"},
22+
},
23+
{
24+
Name: "ValidInputWithSpaces",
25+
Input: []string{"NAME=VALUE WITH SPACES"},
26+
Expected: map[string]string{"NAME": "VALUE WITH SPACES"},
27+
},
28+
{
29+
Name: "ValidInputWithQuotes",
30+
Input: []string{"NAME=\"VALUE WITH QUOTES\""},
31+
Expected: map[string]string{"NAME": `"VALUE WITH QUOTES"`},
32+
},
33+
{
34+
Name: "DelimitedValue",
35+
Input: []string{"NAME=VALUE1,VALUE2"},
36+
Expected: map[string]string{"NAME": "VALUE1,VALUE2"},
37+
},
38+
{
39+
Name: "EmptyValue",
40+
Input: []string{"NAME="},
41+
Expected: map[string]string{"NAME": ""},
42+
},
43+
{
44+
Name: "EmptyValueWithSpaces",
45+
Input: []string{"NAME= "},
46+
Expected: map[string]string{"NAME": " "},
47+
},
48+
{
49+
Name: "LowerCaseName",
50+
Input: []string{"name=VALUE"},
51+
Expected: map[string]string{"name": "VALUE"},
52+
},
53+
{
54+
Name: "MixedCaseName",
55+
Input: []string{"MiXeD_Case=VALUE"},
56+
Expected: map[string]string{"MiXeD_Case": "VALUE"},
57+
},
58+
{
59+
Name: "EqualSignInValue",
60+
Input: []string{"NAME=VALUE=EQUAL"},
61+
Expected: map[string]string{"NAME": "VALUE=EQUAL"},
62+
},
63+
{
64+
Name: "EqualSignInValueWithSpaces",
65+
Input: []string{"NAME=VALUE=EQUAL WITH SPACES"},
66+
Expected: map[string]string{"NAME": "VALUE=EQUAL WITH SPACES"},
67+
},
68+
{
69+
Name: "MultiLineValue",
70+
Input: []string{"NAME=VALUE1\nVALUE2"},
71+
Expected: map[string]string{"NAME": "VALUE1\nVALUE2"},
72+
},
73+
{
74+
Name: "MultiplePairs",
75+
Input: []string{"NAME1=VALUE1", "NAME2=VALUE2"},
76+
Expected: map[string]string{"NAME1": "VALUE1", "NAME2": "VALUE2"},
77+
},
78+
}
79+
80+
for _, c := range cases {
81+
t.Run(c.Name, func(t *testing.T) {
82+
parsed, err := parseAndValidateEnvironment(c.Input)
83+
require.NoError(t, err)
84+
require.Equal(t, c.Expected, parsed)
85+
})
86+
}
87+
})
88+
89+
t.Run("ParsingInvalidInput", func(t *testing.T) {
90+
cases := []struct {
91+
Name string
92+
Input []string
93+
ExpectedErrorMessage string
94+
}{
95+
{
96+
Name: "NameWithoutValue",
97+
Input: []string{"NAME"},
98+
ExpectedErrorMessage: `environment variable "NAME" is not in the KEY=VALUE format`,
99+
},
100+
{
101+
Name: "EmptyName",
102+
Input: []string{"=VALUE"},
103+
ExpectedErrorMessage: `environment variable "=VALUE" is not in the KEY=VALUE format`,
104+
},
105+
}
106+
107+
for _, c := range cases {
108+
t.Run(c.Name, func(t *testing.T) {
109+
_, err := parseAndValidateEnvironment(c.Input)
110+
require.Error(t, err)
111+
require.ErrorContains(t, err, c.ExpectedErrorMessage)
112+
})
113+
}
114+
})
115+
116+
t.Run("EnforceDenyList", func(t *testing.T) {
117+
for _, pattern := range EnvironmentVariableDenyList {
118+
// test that exact matches are rejected
119+
t.Run(fmt.Sprintf("Rejects %q", pattern), func(t *testing.T) {
120+
input := fmt.Sprintf("%s=VALUE", pattern)
121+
_, err := parseAndValidateEnvironment([]string{input})
122+
require.Error(t, err)
123+
require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", pattern))
124+
})
125+
126+
// test that prefix matches are rejected
127+
if strings.HasSuffix(pattern, "*") {
128+
t.Run(fmt.Sprintf("Rejects %q prefix", pattern), func(t *testing.T) {
129+
name := strings.TrimSuffix(pattern, "*") + "SUFFIX"
130+
input := fmt.Sprintf("%s=VALUE", name)
131+
_, err := parseAndValidateEnvironment([]string{input})
132+
require.Error(t, err)
133+
require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", name))
134+
})
135+
}
136+
}
137+
})
138+
139+
t.Run("DuplicateNamesAreRejected", func(t *testing.T) {
140+
input := []string{"NAME=VALUE", "NAME=VALUE2"}
141+
_, err := parseAndValidateEnvironment(input)
142+
require.Error(t, err)
143+
require.ErrorContains(t, err, "environment variable \"NAME\" is already defined")
144+
})
145+
}

pkg/dockerfile/env.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package dockerfile
2+
3+
import (
4+
"maps"
5+
"slices"
6+
7+
"github.com/replicate/cog/pkg/config"
8+
)
9+
10+
func envLineFromConfig(c *config.Config) (string, error) {
11+
vars := c.ParsedEnvironment()
12+
if len(vars) == 0 {
13+
return "", nil
14+
}
15+
16+
out := "ENV"
17+
for _, name := range slices.Sorted(maps.Keys(vars)) {
18+
out = out + " " + name + "=" + vars[name]
19+
}
20+
out += "\n"
21+
22+
return out, nil
23+
}

pkg/dockerfile/fast_generator.go

+8
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,14 @@ func (g *FastGenerator) installSrc(lines []string, weights []weights.Weight) ([]
431431
}
432432

433433
func (g *FastGenerator) entrypoint(lines []string) ([]string, error) {
434+
line, err := envLineFromConfig(g.Config)
435+
if err != nil {
436+
return nil, err
437+
}
438+
if line != "" {
439+
lines = append(lines, line)
440+
}
441+
434442
return append(lines, []string{
435443
"WORKDIR /src",
436444
"ENV VERBOSE=0",

pkg/dockerfile/standard_generator.go

+10
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
143143
if err != nil {
144144
return "", err
145145
}
146+
envs, err := g.envVars()
147+
if err != nil {
148+
return "", err
149+
}
146150
runCommands, err := g.runCommands()
147151
if err != nil {
148152
return "", err
@@ -160,6 +164,7 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
160164
steps := []string{
161165
"#syntax=docker/dockerfile:1.4",
162166
"FROM " + baseImage,
167+
envs,
163168
aptInstalls,
164169
installCog,
165170
pipInstalls,
@@ -177,6 +182,7 @@ func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, e
177182
"FROM " + baseImage,
178183
g.preamble(),
179184
g.installTini(),
185+
envs,
180186
aptInstalls,
181187
installPython,
182188
pipInstalls,
@@ -505,6 +511,10 @@ This is the offending line: %s`, command)
505511
return strings.Join(lines, "\n"), nil
506512
}
507513

514+
func (g *StandardGenerator) envVars() (string, error) {
515+
return envLineFromConfig(g.Config)
516+
}
517+
508518
// writeTemp writes a temporary file that can be used as part of the build process
509519
// It returns the lines to add to Dockerfile to make it available and the filename it ends up as inside the container
510520
func (g *StandardGenerator) writeTemp(filename string, contents []byte) ([]string, string, error) {

0 commit comments

Comments
 (0)