Skip to content

Commit

Permalink
fix checkout without experiment path
Browse files Browse the repository at this point in the history
Signed-off-by: Andreas Jansson <andreas@replicate.ai>
  • Loading branch information
andreasjansson authored and bfirsh committed Mar 10, 2021
1 parent 3afae05 commit 202aa57
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 4 deletions.
92 changes: 92 additions & 0 deletions end-to-end-test/end_to_end_test/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,95 @@ def main():
"cicada.ogg",
]
assert set(actual_paths) == set(expected_paths)


def test_checkout_no_experiment_path(tmpdir, temp_bucket_factory, tmpdir_factory):
tmpdir = str(tmpdir)
repository = "file://" + str(tmpdir_factory.mktemp("repository"))

rand = str(random.randint(0, 100000))
os.mkdir(os.path.join(tmpdir, rand))
with open(os.path.join(tmpdir, rand, rand), "w") as f:
f.write(rand)

with open(os.path.join(tmpdir, "foo.txt"), "w") as f:
f.write("foo bar")

with open(os.path.join(tmpdir, "keepsake.yaml"), "w") as f:
f.write(
"""
repository: {repository}
""".format(
repository=repository
)
)
with open(os.path.join(tmpdir, "train.py"), "w") as f:
f.write(
"""
import os
import keepsake
def main():
experiment = keepsake.init()
experiment.checkpoint(path="foo.txt")
if __name__ == "__main__":
main()
"""
)

env = get_env()
cmd = ["python", "train.py"]
subprocess.run(cmd, cwd=tmpdir, env=env, check=True)

experiments = json.loads(
subprocess.run(
["keepsake", "ls", "--json"],
cwd=tmpdir,
env=env,
capture_output=True,
check=True,
).stdout
)
assert len(experiments) == 1

exp = experiments[0]

# checking out experiment
output_dir = str(tmpdir_factory.mktemp("output"))
subprocess.run(
["keepsake", "checkout", "-o", output_dir, exp["id"]],
cwd=tmpdir,
env=env,
check=True,
)

# Checkout out experiment checks out latest checkpoint
with open(os.path.join(output_dir, "foo.txt")) as f:
assert f.read() == "foo bar"

actual_paths = [
os.path.relpath(path, output_dir) for path in glob(output_dir + "/*")
]
expected_paths = ["foo.txt"]
assert set(actual_paths) == set(expected_paths)

# checking out checkpoint
latest_id = exp["latest_checkpoint"]["id"]

output_dir = str(tmpdir_factory.mktemp("output"))
subprocess.run(
["keepsake", "checkout", "-o", output_dir, latest_id],
cwd=tmpdir,
env=env,
check=True,
)

with open(os.path.join(output_dir, "foo.txt")) as f:
assert f.read() == "foo bar"

actual_paths = [
os.path.relpath(path, output_dir) for path in glob(output_dir + "/*")
]
expected_paths = ["foo.txt"]
assert set(actual_paths) == set(expected_paths)
31 changes: 27 additions & 4 deletions go/pkg/cli/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,36 @@ func overwriteDisplayPathPrompt(displayPath string, force bool) error {
}

if exists {
isEmpty, err := files.DirIsEmpty(displayPath)
isDir, err := files.IsDir(displayPath)
if err != nil {
return err
}
if !isEmpty && !force {
console.Warn("The directory %q is not empty.", displayPath)
console.Warn("%s Make sure they're saved in Git or Keepsake so they're safe!", aurora.Bold("This checkout may overwrite existing files."))
if isDir {
isEmpty, err := files.DirIsEmpty(displayPath)
if err != nil {
return err
}
if !isEmpty && !force {
console.Warn("The directory %q is not empty.", displayPath)
console.Warn("%s Make sure they're saved in Git or Keepsake so they're safe!", aurora.Bold("This checkout may overwrite existing files."))
fmt.Println()
// This is scary! See https://github.com/replicate/keepsake/issues/300
doOverwrite, err := console.InteractiveBool{
Prompt: "Do you want to continue?",
Default: false,
}.Read()
if err != nil {
return err
}
if !doOverwrite {
console.Info("Aborting.")
return nil
}
}
} else if !force {
// it's a file
console.Warn("The file %q exists.", displayPath)
console.Warn("%s Make sure it's saved in Git or Keepsake so it's safe!", aurora.Bold("This checkout may overwrite existing files."))
fmt.Println()
// This is scary! See https://github.com/replicate/keepsake/issues/300
doOverwrite, err := console.InteractiveBool{
Expand Down

0 comments on commit 202aa57

Please sign in to comment.