Skip to content

Commit 4c61f32

Browse files
authored
Recursive symlink copy procedure directory (#109)
* Recursive symlink copy procedure directory * Copy procedure directory into a temporary directory using symlinks * Add more asserts to the test case
1 parent 88ed85c commit 4c61f32

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

internal/tests/procedure_url_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ func TestPrepareProcedureSourceURLLocal(t *testing.T) {
2626
srcDir := fmt.Sprintf("file://%s", fooDir)
2727
fooDst, err := util.PrepareProcedureSourceURL(srcDir)
2828
assert.NoError(t, err)
29-
assert.Equal(t, fooDir, fooDst)
29+
assert.DirExists(t, fooDst)
30+
assert.FileExists(t, filepath.Join(fooDst, "cog.yaml"))
31+
fooPy := filepath.Join(fooDst, "predict.py")
32+
assert.FileExists(t, fooPy)
33+
assert.Contains(t, string(must.Get(os.ReadFile(fooPy))), "'predicting foo'")
3034
}
3135

3236
func TestPrepareProcedureSourceURLRemote(t *testing.T) {

internal/util/procedure.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,37 @@ import (
55
"encoding/hex"
66
"fmt"
77
"io"
8+
"io/fs"
89
"net/http"
910
"net/url"
1011
"os"
1112
"os/exec"
1213
"path/filepath"
1314
)
1415

16+
func copyRecursiveSymlink(srcRoot, dstRoot string) error {
17+
return filepath.WalkDir(srcRoot, func(path string, d fs.DirEntry, err error) error {
18+
if err != nil {
19+
return err
20+
}
21+
22+
relPath, err := filepath.Rel(srcRoot, path)
23+
if err != nil {
24+
return err
25+
}
26+
27+
dstPath := filepath.Join(dstRoot, relPath)
28+
29+
if d.IsDir() {
30+
if path != srcRoot {
31+
return os.MkdirAll(dstPath, 0o755)
32+
}
33+
return nil
34+
}
35+
return os.Symlink(path, dstPath)
36+
})
37+
}
38+
1539
func PrepareProcedureSourceURL(srcURL string) (string, error) {
1640
u, err := url.Parse(srcURL)
1741
if err != nil {
@@ -26,7 +50,15 @@ func PrepareProcedureSourceURL(srcURL string) (string, error) {
2650
if !stat.IsDir() {
2751
return "", fmt.Errorf("invalid procedure source URL: %s", srcURL)
2852
}
29-
return u.Path, nil
53+
tmpDir, err := os.MkdirTemp("", "procedure*")
54+
if err != nil {
55+
return "", err
56+
}
57+
err = copyRecursiveSymlink(u.Path, tmpDir)
58+
if err != nil {
59+
return "", err
60+
}
61+
return tmpDir, nil
3062
} else if u.Scheme == "http" || u.Scheme == "https" {
3163
// http://host/path/to/tarball
3264
sha := sha256.New()

0 commit comments

Comments
 (0)