Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto wandb tag with benchmark.py #308

Merged
merged 6 commits into from
Nov 2, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Auto wandb tag with benchmark.py
  • Loading branch information
vwxyzjn committed Oct 31, 2022
commit 961c40c3966b91aa615e16d270f999850ae1e7e7
32 changes: 32 additions & 0 deletions cleanrl_utils/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import argparse
import os
import shlex
import subprocess

import requests


def parse_args():
# fmt: off
Expand All @@ -16,6 +19,8 @@ def parse_args():
help="the number of random seeds")
parser.add_argument('--workers', type=int, default=0,
help='the number of eval workers to run benchmark experimenets (skips evaluation when set to 0)')
parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, the runs will be tagged with the output from `git describe --tags` (e.g., v1.0.0b2-11-g5db4db7)")
args = parser.parse_args()
# fmt: on
return args
Expand All @@ -29,8 +34,35 @@ def run_experiment(command: str):
assert return_code == 0


def autotag() -> str:
git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()

# try finding the pull request number on github
prs = requests.get(f"https://api.github.com/repos/vwxyzjn/cleanrl/commits/{git_commit}/pulls")
if prs.status_code == 200:
prs = prs.json()
if len(prs) > 0:
pr = prs[0]
pr_number = pr["number"]
pr["title"]
pr["html_url"]
wandb_tag = f"{git_tag},pr{pr_number}"
else:
wandb_tag = f"{git_tag}"
return wandb_tag


if __name__ == "__main__":
args = parse_args()
if args.auto_tag:
if "WANDB_TAGS" in os.environ:
raise ValueError(
"WANDB_TAGS is already set. Please unset it before running this script or run the script with --auto-tag False"
)
wandb_tag = autotag()
os.environ["WANDB_TAGS"] = wandb_tag

commands = []
for seed in range(1, args.num_seeds + 1):
for env_id in args.env_ids:
Expand Down